wenlianghuang commited on
Commit
0d900a6
·
2 Parent(s): 485f02a979763a

Add main source

Browse files
.gitignore CHANGED
@@ -19,3 +19,12 @@ token.json
19
 
20
  *test_check*
21
  .DS_Store
 
 
 
 
 
 
 
 
 
 
19
 
20
  *test_check*
21
  .DS_Store
22
+
23
+ <<<<<<< HEAD
24
+ .cursor/*
25
+ chroma_db*/*
26
+ =======
27
+ chroma_db*/
28
+
29
+ .cursor/*
30
+ >>>>>>> 8862b07082bc878942f9e22816227a2e9a718b23
Deep_Agent_Gradio_RAG_localLLM_main.py CHANGED
@@ -25,7 +25,7 @@ def main():
25
  """主函數:初始化系統並啟動 Gradio 界面"""
26
  print("\n🚀 Deep Research Agent with RAG (Local MLX Edition) 啟動!")
27
  print("💡 本系統整合了:股票查詢、網路搜尋、PDF 知識庫查詢功能\n")
28
- print("📦 使用本地 MLX 模型,保護隱私,無需 API 金鑰\n")
29
 
30
  # 初始化 Parlant SDK
31
  print("🔧 正在初始化 Parlant SDK...")
 
25
  """主函數:初始化系統並啟動 Gradio 界面"""
26
  print("\n🚀 Deep Research Agent with RAG (Local MLX Edition) 啟動!")
27
  print("💡 本系統整合了:股票查詢、網路搜尋、PDF 知識庫查詢功能\n")
28
+
29
 
30
  # 初始化 Parlant SDK
31
  print("🔧 正在初始化 Parlant SDK...")
OLLAMA_SETUP.md ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ollama 設置指南
2
+
3
+ 本指南說明如何在 Deep Agentic AI Tool 中設置和使用 Ollama,特別是 Llama 3.2 3B 模型。
4
+
5
+ ## 📋 前置需求
6
+
7
+ - macOS 或 Linux 系統
8
+ - 至少 16GB 記憶體(推薦)
9
+ - Python >= 3.13
10
+
11
+ ## 🚀 安裝步驟
12
+
13
+ ### 1. 安裝 Ollama
14
+
15
+ **macOS:**
16
+ ```bash
17
+ brew install ollama
18
+ ```
19
+
20
+ 或從官網下載:https://ollama.com
21
+
22
+ **Linux:**
23
+ ```bash
24
+ curl -fsSL https://ollama.com/install.sh | sh
25
+ ```
26
+
27
+ ### 2. 下載 Llama 3.2 模型
28
+
29
+ ```bash
30
+ ollama pull llama3.2:3b
31
+ ```
32
+
33
+ 這會下載約 2GB 的模型文件。
34
+
35
+ ### 3. 啟動 Ollama 服務
36
+
37
+ Ollama 通常會自動啟動,如果需要手動啟動:
38
+
39
+ ```bash
40
+ ollama serve
41
+ ```
42
+
43
+ 服務預設運行在 `http://localhost:11434`
44
+
45
+ ### 4. 驗證安裝
46
+
47
+ 測試模型是否可用:
48
+
49
+ ```bash
50
+ ollama run llama3.2:3b "Hello, how are you?"
51
+ ```
52
+
53
+ ## ⚙️ 配置專案
54
+
55
+ ### 1. 更新環境變數
56
+
57
+ 在專案根目錄的 `.env` 文件中添加:
58
+
59
+ ```env
60
+ # 啟用 Ollama
61
+ USE_OLLAMA=true
62
+ OLLAMA_BASE_URL=http://localhost:11434
63
+ OLLAMA_MODEL=llama3.2:3b
64
+ ```
65
+
66
+ ### 2. 可選配置
67
+
68
+ 如果需要使用其他 Ollama 模型,可以修改:
69
+
70
+ ```env
71
+ OLLAMA_MODEL=qwen2.5:7b # 使用 Qwen2.5
72
+ OLLAMA_MODEL=llama3.1:8b # 使用 Llama 3.1
73
+ OLLAMA_MODEL=deepseek-r1:7b # 使用 DeepSeek-R1
74
+ OLLAMA_MODEL=mistral:7b # 使用 Mistral
75
+ ```
76
+
77
+ ## 🎯 使用方式
78
+
79
+ 系統會按照以下優先順序自動選擇 LLM:
80
+
81
+ 1. **Groq API**(如果配置了 `GROQ_API_KEY`)
82
+ 2. **Ollama**(如果 `USE_OLLAMA=true` 且服務可用)
83
+ 3. **MLX 模型**(備援選項)
84
+
85
+ 當 Groq API 額度用完時,系統會自動切換到 Ollama(如果啟用),否則使用 MLX 模型。
86
+
87
+ ## 🔍 檢查當前使用的模型
88
+
89
+ 啟動應用後,查看控制台輸出:
90
+
91
+ - `✅ 使用 Groq API (優先)` - 使用 Groq API
92
+ - `✅ 使用 Ollama 模型 (llama3.2:3b)` - 使用 Ollama
93
+ - `ℹ️ 使用本地 MLX 模型` - 使用 MLX 模型
94
+
95
+ ## 🐛 故障排除
96
+
97
+ ### Ollama 服務無法連接
98
+
99
+ **問題:** `⚠️ Ollama 初始化失敗: Connection refused`
100
+
101
+ **解決方案:**
102
+ 1. 確認 Ollama 服務正在運行:`ollama serve`
103
+ 2. 檢查端口是否被占用:`lsof -i :11434`
104
+ 3. 確認 `OLLAMA_BASE_URL` 配置正確
105
+
106
+ ### 模型找不到
107
+
108
+ **問題:** `⚠️ Ollama 初始化失敗: model not found`
109
+
110
+ **解決方案:**
111
+ ```bash
112
+ # 下載模型
113
+ ollama pull llama3.2:3b
114
+
115
+ # 列出已安裝的模型
116
+ ollama list
117
+ ```
118
+
119
+ ### 記憶體不足
120
+
121
+ **問題:** 系統運行緩慢或崩潰
122
+
123
+ **解決方案:**
124
+ - Llama 3.2:3B 需要約 2GB RAM
125
+ - 確保系統有足夠的可用記憶體(推薦至少 8GB)
126
+ - 這個模型已經很輕量,適合 16GB 記憶體的系統
127
+
128
+ ## 📊 模型比較
129
+
130
+ | 模型 | 大小 | 記憶體需求 | 特點 |
131
+ |------|------|-----------|------|
132
+ | llama3.2:3b | ~2GB | ~4GB | 輕量高效,適合 16GB 記憶體系統,Meta 開源 |
133
+ | deepseek-r1:7b | ~4.7GB | ~8GB | 優秀的推理能力,適合數學、編程 |
134
+ | qwen2.5:7b | ~4.5GB | ~8GB | 通用能力強,中英文支援好 |
135
+ | llama3.1:8b | ~4.6GB | ~8GB | Meta 開源,性能穩定 |
136
+ | mistral:7b | ~4.1GB | ~7GB | 速度快,效率高 |
137
+
138
+ ## 💡 性能優化建議
139
+
140
+ 1. **優先使用 Groq API**:如果可用,Groq API 速度最快
141
+ 2. **Ollama 作為備援**:當 Groq 不可用時,Ollama 提供良好的本地推理
142
+ 3. **MLX 作為最後備援**:在 Apple Silicon 上,MLX 模型有硬體優化
143
+
144
+ ## 📚 相關資源
145
+
146
+ - [Ollama 官方文檔](https://ollama.com/docs)
147
+ - [Llama 3.2 模型資訊](https://ollama.com/library/llama3.2)
148
+ - [LangChain Ollama 整合](https://python.langchain.com/docs/integrations/llms/ollama)
149
+
150
+ ---
151
+
152
+ **注意**:首次使用時,Ollama 會下載模型文件,這可能需要一些時間,請耐心等待。
153
+
PRIVATE_FILE_RAG_GUIDE.md ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 私有檔案 RAG 功能使用指南
2
+
3
+ ## 總覽
4
+
5
+ 本功能整合了 `Learn_RAG` 專案的 RAG 系統,讓使用者可以上傳私有檔案(如 PDF、DOCX、TXT),並基於這些檔案內容進行智慧問答。系統採用了先進的混合檢索與 LLM 技術,以提供準確、相關的回答。
6
+
7
+ ## 功能特色
8
+
9
+ - ✅ **支援多種檔案格式**:PDF、DOCX、DOC、TXT。
10
+ - ✅ **支援多檔案上傳**:可一次上傳並處理多個檔案。
11
+ - ✅ **混合檢索**:結合 BM25(關鍵字檢索)與向量檢索(語義檢索),大幅提升檢索準確度。
12
+ - ✅ **可選重排序**:使用 BGE Reranker 模型進一步優化檢索結果的相關性。
13
+ - ✅ **兩種分塊模式**:
14
+ - **語義分塊(推薦)**:智慧切分文件,確保語義完整性,不會在句子中間斷開,提升檢索品質。
15
+ - **字元分塊**:依固定字數切分,處理速度快,適合快速測試。
16
+ - ✅ **智慧回答生成**:
17
+ - **LLM 自動切換策略**:系統會自動根據可用性選擇最佳的 LLM,優先順序為:**Groq API > Ollama > MLX 本地模型**。
18
+ - **自動化提示工程**:自動偵測檔案類型(如學術論文、履歷、通用文件)以調整提問風格,生成更貼切的回答。
19
+ - ✅ **支援中英文問答**。
20
+
21
+ ## 使用方法
22
+
23
+ ### 1. 準備工作
24
+
25
+ 確保 `Learn_RAG` 專案與本專案 (`Deep_Agentic_AI_Tool`) 位於**同一個父目錄**下。正確的目錄結構應如下:
26
+
27
+ ```
28
+ /some_parent_directory/
29
+ ├─── Deep_Agentic_AI_Tool/ (本專案)
30
+ └─── Learn_RAG/
31
+ ```
32
+
33
+ 此外,請確保 `Learn_RAG` 的依賴已安裝。您可以在 `Learn_RAG` 目錄下執行 `uv sync` 或 `pip install -r requirements.txt`。
34
+
35
+ ### 2. 啟動系統
36
+
37
+ 執行主程式以啟動 Gradio 網頁介面:
38
+ ```bash
39
+ python Deep_Agent_Gradio_RAG_localLLM_main.py
40
+ ```
41
+
42
+ ### 3. 使用步驟
43
+
44
+ 1. **開啟介面**:在瀏覽器中開啟 Gradio 介面(預設為 `http://0.0.0.0:7860`),並點擊 **"📚 Private File RAG"** 標籤頁。
45
+
46
+ 2. **上傳檔案**:點擊或拖曳檔案至 **"📁 上傳檔案"** 區域,可選擇一個或多個 PDF、DOCX 或 TXT 檔案。
47
+
48
+ 3. **設定分塊模式**:
49
+ - **使用語義分塊(推薦)**:勾選此選項以獲得最佳的檢索品質。處理時間較長,但效果最好。
50
+ - **調整參數(可選)**:介面提供了對兩種分塊模式的進階參數調整,您可以根據需求調整,或直接使用已優化的預設值。
51
+
52
+ 4. **處理檔案**:點擊 **"📝 處理檔案"** 按鈕。系統會根據您的設定進行分塊、建立索引並初始化 RAG 系統。請等待處理狀態顯示 "✅ 文件處理完成"。
53
+ - **注意**:首次使用時,系統需要下載 Embedding 模型,可能需要數分鐘時間,請耐心等候。
54
+
55
+ 5. **提出問題**:在 **"❓ 請輸入您的問題"** 輸入框中,輸入您想詢問關於檔案內容的問題。
56
+
57
+ 6. **調整查詢選項**:
58
+ - **返回結果數量**:可調整檢索到的相關文件片段數量(預設為 3)。
59
+ - **使用 LLM 生成回答**:勾選此項,AI 會總結檢索到的內容並生成流暢的回答。若取消勾選,則僅顯示原始的文件片段。
60
+
61
+ 7. **執行查詢**:點擊 **"🔍 查詢"** 按鈕。
62
+
63
+ 8. **檢視結果**:
64
+ - **💬 AI 回答**:顯示由 LLM 生成的最終回答。
65
+ - **📄 檢索到的文件片段**:顯示用於生成回答的原始文件內容、來源及相關性分數,方便您查證。
66
+
67
+ 9. **清除與重置**:點擊 **"🗑️ 清除"** 按鈕可重設當前會話,讓您重新上傳檔案。
68
+
69
+ ## 技術細節
70
+
71
+ ### LLM 使用策略
72
+
73
+ 本系統採用彈性的 LLM 調度策略,無需手動設定:
74
+ 1. 🥇 **Groq API**:若您在環境變數中設定了 `GROQ_API_KEY`,系統會優先使用速度極快的 Groq API。
75
+ 2. 🥈 **Ollama**:若 Groq API 不可用或額度用盡,系統會自動切換至本地運行的 Ollama 模型(如 Llama3.2)。
76
+ 3. 🥉 **MLX**:若前兩者皆不可用,系統會使用 Apple MLX 在本地運行模型(如 Qwen2.5)作為最終備案。
77
+
78
+ ### 分塊模式詳解
79
+
80
+ - **字元分塊**(預設 `chunk_size: 500`, `chunk_overlap: 100`):
81
+ - **優點**:處理速度快。
82
+ - **適用場景**:快速測試、對語義完整性要求不高的文件。
83
+ - **參數說明**:
84
+ - `分塊大小`:每個區塊的字元數。較小值粒度更細,較大值上下文更完整。
85
+ - `分塊重疊`:相鄰區塊間重疊的字元數,有助於保持上下文連貫。
86
+
87
+ - **語義分塊**(預設 `threshold: 1.0`, `min_chunk_size: 100`):
88
+ - **優點**:根據語義邊界切分,能保持句子和段落的完整性,檢索品質更高。
89
+ - **適用場景**:專業文件、報告、論文等需要精準理解上下文的場景。
90
+ - **參數說明**:
91
+ - `語義分塊閾值`:控制分塊的敏感度。數值越小,分塊越細。建議值為 0.8-1.2。
92
+ - `最小分塊大小`:過���的區塊會被合併,以避免碎片化。
93
+
94
+ ### 檢索系統
95
+
96
+ 1. **Embedding 模型**:使用 `sentence-transformers/all-MiniLM-L6-v2` 將文字轉換為向量,此模型輕量且高效。
97
+ 2. **向量資料庫**:使用 `ChromaDB` 儲存並索引向量,資料庫會持久化儲存於 `./chroma_db_private` 目錄。
98
+ 3. **混合檢索**:結合 `BM25`(基於關鍵字)和 `向量檢索`(基於語義),並透過 `RRF` (Reciprocal Rank Fusion) 演算法融合結果,兼顧關鍵字匹配和語義相似度。
99
+ 4. **重排序器(Reranker)**:使用 `BAAI/bge-reranker-base` 模型對混合檢索的結果進行二次排序,將最相關的片段排在最前面,極大化提升了最終答案的品質。
100
+
101
+ ## 常見問題
102
+
103
+ ### Q: 處理檔案時出錯或提示 `Learn_RAG` 模組不可用?
104
+ **A:** 請檢查:
105
+ 1. **專案位置**:確保 `Learn_RAG` 專案目錄與 `Deep_Agentic_AI_Tool` 位於同一父目錄下。
106
+ 2. **依賴安裝**:確認您已安裝 `Learn_RAG` 的所有 Python 依賴。最簡單的方式是進入 `Learn_RAG` 目錄並執行 `uv sync`。
107
+ 3. **檔案格式**:確認您上傳的是支援的格式(PDF, DOCX, DOC, TXT)且檔案未損毀。
108
+
109
+ ### Q: AI 無法生成回答,或回答很慢?
110
+ **A:** 請檢查:
111
+ 1. **LLM 服務**:如果您想使用 Ollama,請確保 Ollama 服務正在本地運行(可透過終端機執行 `ollama serve`)。
112
+ 2. **模型下載**:確認 Ollama 需要的模型已經下載(如 `ollama pull llama3.2:3b`)。
113
+ 3. **LLM 狀態**:系統會自動從 Groq 切換至本地模型。若切換至 MLX,處理速度會較慢,請耐心等待。若不需生成式回答,可取消勾選 "使用 LLM 生成回答" 以直接查看檢索結果。
114
+
115
+ ### Q: "清除" 按鈕的功能是什麼?
116
+ **A:** "清除" 按鈕會重設當前 Gradio 會話中的 RAG 系統,清空已上傳的檔案和記憶體中的索引。這讓您可以重新上傳並處理新的一批檔案。
117
+ **注意**:此按鈕**不會**刪除磁碟上 `./chroma_db_private` 目錄中持久化的向量資料庫。若要完全清空所有資料,您需要手動刪除該目錄。
118
+
119
+ ### Q: 處理檔案速度很慢?
120
+ **A:**
121
+ 1. 首次使用時,系統需要下載數百 MB 的 Embedding 模型,請耐心等待。
122
+ 2. 語義分塊模式會進行複雜的計算,處理時間比字元分塊長,但效果更好。
123
+ 3. 檔案越大、數量越多,處理時間越長。
124
+
125
+ ## 依賴項
126
+
127
+ - `Learn_RAG` 專案及其所有依賴項。
128
+ - `langchain-community`, `sentence-transformers`, `chromadb`, `rank_bm25`, `pypdf`, `docx2txt` 等。
129
+ - `ollama` (若使用本地 Ollama LLM)。
README.md CHANGED
@@ -66,9 +66,29 @@
66
  # 使用 uv(推薦)
67
  uv sync
68
 
 
69
  # 或使用 pip
70
  pip install -e .
71
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  ### 2. 環境變數配置
74
 
@@ -220,6 +240,7 @@ Deep_Agentic_AI_Tool/
220
  3. 在 `get_tools_list()` 中添加工具
221
  4. 代理會自動發現並使用新工具
222
 
 
223
  ### 修改代理邏輯
224
 
225
  - **規劃邏輯**:編輯 `deep_agent_rag/agents/planner.py`
@@ -231,6 +252,41 @@ Deep_Agentic_AI_Tool/
231
  編輯 `deep_agent_rag/ui/gradio_interface.py` 修改 Web 界面。
232
 
233
  **詳細開發指南請參考:[系統架構](ARCHITECTURE.md#開發指南)**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  ## 📦 主要依賴
236
 
@@ -288,14 +344,75 @@ Deep_Agentic_AI_Tool/
288
 
289
  ## 📧 聯絡
290
 
 
291
  [添加聯絡資訊]
 
 
 
 
 
 
 
 
 
 
 
292
 
293
  ## 🙏 致謝
294
 
 
295
  - **LangChain & LangGraph**:優秀的代理框架
296
  - **MLX Team**:高效的本地模型推理
297
  - **Qwen Team**:Qwen2.5 模型
298
  - **Jina AI**:嵌入模型
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
  ---
301
 
 
66
  # 使用 uv(推薦)
67
  uv sync
68
 
69
+ <<<<<<< HEAD
70
  # 或使用 pip
71
  pip install -e .
72
  ```
73
+ =======
74
+ 3. **Set up environment variables** (create a `.env` file in the root directory):
75
+ ```env
76
+ # Optional: Groq API (for faster inference)
77
+ GROQ_API_KEY=your_groq_api_key_here
78
+
79
+ # Optional: Ollama (for local inference with Llama 3.2 or other models)
80
+ USE_OLLAMA=true
81
+ OLLAMA_BASE_URL=http://localhost:11434
82
+ OLLAMA_MODEL=llama3.2:3b
83
+
84
+ # Optional: Tavily API (for web search)
85
+ TAVILY_API_KEY=your_tavily_api_key_here
86
+
87
+ # Optional: Gmail API credentials
88
+ GMAIL_CREDENTIALS_FILE=credentials.json
89
+ GMAIL_TOKEN_FILE=token.json
90
+ ```
91
+ >>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
92
 
93
  ### 2. 環境變數配置
94
 
 
240
  3. 在 `get_tools_list()` 中添加工具
241
  4. 代理會自動發現並使用新工具
242
 
243
+ <<<<<<< HEAD
244
  ### 修改代理邏輯
245
 
246
  - **規劃邏輯**:編輯 `deep_agent_rag/agents/planner.py`
 
252
  編輯 `deep_agent_rag/ui/gradio_interface.py` 修改 Web 界面。
253
 
254
  **詳細開發指南請參考:[系統架構](ARCHITECTURE.md#開發指南)**
255
+ =======
256
+ The system supports multiple LLM backends with automatic fallback (priority order):
257
+
258
+ 1. **Primary**: Groq API (fastest, requires API key)
259
+ - Model: `llama-3.3-70b-versatile`
260
+ - Automatically used if `GROQ_API_KEY` is set
261
+
262
+ 2. **Secondary**: Ollama (local inference, excellent reasoning capabilities)
263
+ - Default Model: `llama3.2:3b` (Llama 3.2 3B)
264
+ - Requires Ollama installed and model downloaded
265
+ - Enable with `USE_OLLAMA=true` in `.env`
266
+ - Lightweight and efficient, suitable for 16GB memory systems
267
+ - Automatically used when Groq API is unavailable or quota exhausted
268
+
269
+ 3. **Fallback**: Local MLX Model (privacy-preserving, no API key needed)
270
+ - Model: `mlx-community/Qwen2.5-Coder-7B-Instruct-4bit`
271
+ - Automatically used when both Groq API and Ollama are unavailable
272
+
273
+ The system automatically switches between backends based on availability.
274
+
275
+ **Setting up Ollama:**
276
+ ```bash
277
+ # Install Ollama (if not already installed)
278
+ # macOS: brew install ollama
279
+ # Or download from https://ollama.com
280
+
281
+ # Download Llama 3.2 model
282
+ ollama pull llama3.2:3b
283
+
284
+ # Start Ollama service (usually runs automatically)
285
+ ollama serve
286
+ ```
287
+
288
+ ## ⚙️ Configuration
289
+ >>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
290
 
291
  ## 📦 主要依賴
292
 
 
344
 
345
  ## 📧 聯絡
346
 
347
+ <<<<<<< HEAD
348
  [添加聯絡資訊]
349
+ =======
350
+ - **LangChain**: Agent framework and tool integration
351
+ - **LangGraph**: Agent orchestration and workflow management
352
+ - **MLX/MLX-LM**: Local model inference (Apple Silicon optimized)
353
+ - **LangChain Ollama**: Ollama integration for local models
354
+ - **Gradio**: Web interface
355
+ - **ChromaDB**: Vector database for RAG
356
+ - **Tavily**: Web search API
357
+ - **yfinance**: Stock data retrieval
358
+ - **Google API Client**: Gmail API integration
359
+ >>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
360
 
361
  ## 🙏 致謝
362
 
363
+ <<<<<<< HEAD
364
  - **LangChain & LangGraph**:優秀的代理框架
365
  - **MLX Team**:高效的本地模型推理
366
  - **Qwen Team**:Qwen2.5 模型
367
  - **Jina AI**:嵌入模型
368
+ =======
369
+ ### MLX Model Issues
370
+
371
+ - **Model not loading**: Ensure you have sufficient disk space and memory
372
+ - **Slow inference**: This is normal for local models. Consider using Groq API for faster results
373
+
374
+ ### Groq API Issues
375
+
376
+ - **Quota exhausted**: The system automatically falls back to Ollama (if enabled) or local MLX model
377
+ - **API errors**: Check your `GROQ_API_KEY` in `.env` file
378
+
379
+ ### Ollama Issues
380
+
381
+ - **Ollama not starting**: Ensure Ollama service is running (`ollama serve`)
382
+ - **Model not found**: Download the model first (`ollama pull llama3.2:3b`)
383
+ - **Connection errors**: Check `OLLAMA_BASE_URL` in `.env` (default: `http://localhost:11434`)
384
+ - **Memory issues**: Llama 3.2:3B requires ~2GB RAM, suitable for systems with 16GB memory
385
+
386
+ ### RAG System Issues
387
+
388
+ - **PDF not found**: Ensure the PDF file exists at the path specified in `config.py`
389
+ - **Embedding model errors**: The system will attempt to re-download the model if cache is corrupted
390
+
391
+ ### Gmail API Issues
392
+
393
+ - **Authorization errors**: Delete `token.json` and re-authorize
394
+ - **Credentials not found**: Ensure `credentials.json` is in the project root
395
+ - See `GMAIL_API_SETUP.md` for detailed setup instructions
396
+
397
+ ## 📝 License
398
+
399
+ [Add your license information here]
400
+
401
+ ## 🤝 Contributing
402
+
403
+ [Add contribution guidelines here]
404
+
405
+ ## 📧 Contact
406
+
407
+ [Add contact information here]
408
+
409
+ ## 🙏 Acknowledgments
410
+
411
+ - **LangChain & LangGraph**: For the excellent agent framework
412
+ - **MLX Team**: For efficient local model inference
413
+ - **Qwen Team**: For the Qwen2.5 model
414
+ - **Jina AI**: For the embedding model
415
+ >>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
416
 
417
  ---
418
 
deep_agent_rag/config.py CHANGED
@@ -49,6 +49,13 @@ GROQ_MAX_TOKENS = 2048
49
  GROQ_TEMPERATURE = 0.7
50
  USE_GROQ_FIRST = True # 是否优先使用 Groq API
51
 
 
 
 
 
 
 
 
52
  # Email 配置 - 使用 Gmail API
53
  EMAIL_SENDER = "matthuang46@gmail.com"
54
  # Gmail API 配置
 
49
  GROQ_TEMPERATURE = 0.7
50
  USE_GROQ_FIRST = True # 是否优先使用 Groq API
51
 
52
+ # Ollama 配置
53
+ OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
54
+ OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "llama3.2:3b") # Llama 3.2 3B
55
+ OLLAMA_MAX_TOKENS = 2048
56
+ OLLAMA_TEMPERATURE = 0.7
57
+ USE_OLLAMA = os.getenv("USE_OLLAMA", "false").lower() == "true" # 是否啟用 Ollama
58
+
59
  # Email 配置 - 使用 Gmail API
60
  EMAIL_SENDER = "matthuang46@gmail.com"
61
  # Gmail API 配置
deep_agent_rag/rag/adaptive_rag_selector.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 自适应 RAG 方法选择器
3
+ 根据查询和文件特征自动选择最佳的 RAG 方法
4
+ """
5
+ from typing import Dict, List, Optional
6
+ from enum import Enum
7
+ import re
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class RAGMethod(Enum):
14
+ """可用的 RAG 方法"""
15
+ BASIC = "basic" # 基础 RAG(当前使用的)
16
+ SUBQUERY = "subquery" # 子查询分解
17
+ HYDE = "hyde" # 假设文档嵌入
18
+ STEP_BACK = "step_back" # 后退推理
19
+ HYBRID_SUBQUERY_HYDE = "hybrid_subquery_hyde" # 混合子查询+HyDE
20
+ TRIPLE_HYBRID = "triple_hybrid" # 三重混合
21
+
22
+
23
+ class QueryComplexity(Enum):
24
+ """查询复杂度"""
25
+ SIMPLE = "simple" # 简单查询(单问题,短句)
26
+ MODERATE = "moderate" # 中等复杂度(包含多个概念)
27
+ COMPLEX = "complex" # 复杂查询(多部分问题,需要分解)
28
+ VERY_COMPLEX = "very_complex" # 非常复杂(多个相关问题)
29
+
30
+
31
+ class QueryType(Enum):
32
+ """查询类型"""
33
+ FACTUAL = "factual" # 事实性查询("什么是X")
34
+ CONCEPTUAL = "conceptual" # 概念性查询("如何理解X")
35
+ COMPARATIVE = "comparative" # 比较性查询("X和Y的区别")
36
+ PRINCIPLE = "principle" # 原理性查询("X的工作原理")
37
+ MULTI_ASPECT = "multi_aspect" # 多面向查询(包含多个问题)
38
+
39
+
40
+ class AdaptiveRAGSelector:
41
+ """
42
+ 自适应 RAG 方法选择器
43
+
44
+ 根据以下特征选择最佳 RAG 方法:
45
+ 1. 查询复杂度
46
+ 2. 查询类型
47
+ 3. 文件数量和类型
48
+ 4. 文档复杂度
49
+ """
50
+
51
+ def __init__(self):
52
+ """初始化选择器"""
53
+ pass
54
+
55
+ def analyze_query(self, query: str) -> Dict:
56
+ """
57
+ 分析查询特征
58
+
59
+ Args:
60
+ query: 用户查询问题
61
+
62
+ Returns:
63
+ 包含查询特征的字典
64
+ """
65
+ query_lower = query.lower()
66
+ query_len = len(query)
67
+ word_count = len(query.split())
68
+
69
+ # 检测查询复杂度
70
+ complexity = self._detect_complexity(query, word_count)
71
+
72
+ # 检测查询类型
73
+ query_type = self._detect_query_type(query, query_lower)
74
+
75
+ # 检测是否包含多个问题
76
+ question_count = query.count('?') + query.count('?')
77
+ has_multiple_questions = question_count > 1
78
+
79
+ # 检测是否包含比较性词汇
80
+ comparison_keywords = ['vs', 'versus', 'difference', '区别', '比较', 'compare', '对比', '和', 'and', '与']
81
+ is_comparative = any(kw in query_lower for kw in comparison_keywords)
82
+
83
+ # 检测是否包含专业术语
84
+ technical_indicators = [
85
+ '原理', 'mechanism', 'algorithm', 'architecture', 'model', 'system',
86
+ '原理', '机制', '算法', '架构', '模型', '系统', '方法', 'method',
87
+ '如何工作', 'how does', 'how do', 'work', 'function'
88
+ ]
89
+ has_technical_terms = any(ind in query_lower for ind in technical_indicators)
90
+
91
+ # 检测是否包含"为什么"、"如何"等需要解释的词汇
92
+ explanation_keywords = ['为什么', 'why', '如何', 'how', 'explain', '解释', '说明']
93
+ needs_explanation = any(kw in query_lower for kw in explanation_keywords)
94
+
95
+ return {
96
+ 'complexity': complexity,
97
+ 'type': query_type,
98
+ 'word_count': word_count,
99
+ 'length': query_len,
100
+ 'has_multiple_questions': has_multiple_questions,
101
+ 'is_comparative': is_comparative,
102
+ 'has_technical_terms': has_technical_terms,
103
+ 'needs_explanation': needs_explanation,
104
+ 'question_count': question_count
105
+ }
106
+
107
+ def _detect_complexity(self, query: str, word_count: int) -> QueryComplexity:
108
+ """检测查询复杂度"""
109
+ # 简单查询:短句,单问题
110
+ if word_count <= 10 and query.count('?') + query.count('?') <= 1:
111
+ return QueryComplexity.SIMPLE
112
+
113
+ # 中等复杂度:中等长度,可能包含多个概念
114
+ if word_count <= 25:
115
+ return QueryComplexity.MODERATE
116
+
117
+ # 复杂查询:长句,多个问题或概念
118
+ if word_count <= 50:
119
+ return QueryComplexity.COMPLEX
120
+
121
+ # 非常复杂:很长,多个问题
122
+ return QueryComplexity.VERY_COMPLEX
123
+
124
+ def _detect_query_type(self, query: str, query_lower: str) -> QueryType:
125
+ """检测查询类型"""
126
+ # 比较性查询
127
+ if any(kw in query_lower for kw in ['vs', 'versus', 'difference', '区别', '比较', 'compare', '对比', '和', 'and', '与']):
128
+ return QueryType.COMPARATIVE
129
+
130
+ # 原理性查询
131
+ if any(kw in query_lower for kw in ['原理', 'principle', 'how does', 'how do', 'mechanism', '如何工作', '工作原理']):
132
+ return QueryType.PRINCIPLE
133
+
134
+ # 概念性查询
135
+ if any(kw in query_lower for kw in ['什么是', 'what is', '理解', 'understand', 'explain', '解释']):
136
+ return QueryType.CONCEPTUAL
137
+
138
+ # 多面向查询
139
+ if query.count('?') + query.count('?') > 1:
140
+ return QueryType.MULTI_ASPECT
141
+
142
+ # 默认:事实性查询
143
+ return QueryType.FACTUAL
144
+
145
+ def analyze_files(self, file_paths: List[str], documents: Optional[List[Dict]] = None) -> Dict:
146
+ """
147
+ 分析文件特征
148
+
149
+ Args:
150
+ file_paths: 文件路径列表
151
+ documents: 文档列表(可选,如果已处理)
152
+
153
+ Returns:
154
+ 包含文件特征的字典
155
+ """
156
+ file_count = len(file_paths)
157
+
158
+ # 检测文件类型
159
+ file_types = []
160
+ for path in file_paths:
161
+ if path.endswith('.pdf'):
162
+ file_types.append('pdf')
163
+ elif path.endswith(('.docx', '.doc')):
164
+ file_types.append('docx')
165
+ else:
166
+ file_types.append('txt')
167
+
168
+ # 分析文档复杂度(如果有文档)
169
+ total_chunks = len(documents) if documents else 0
170
+ avg_chunk_size = 0
171
+ if documents:
172
+ chunk_sizes = [len(doc.get('content', '')) for doc in documents]
173
+ avg_chunk_size = sum(chunk_sizes) / len(chunk_sizes) if chunk_sizes else 0
174
+
175
+ # 检测是否为学术论文(基于文件名或内容)
176
+ is_academic = any('paper' in path.lower() or 'arxiv' in path.lower() or
177
+ path.endswith('.pdf') for path in file_paths)
178
+
179
+ return {
180
+ 'file_count': file_count,
181
+ 'file_types': file_types,
182
+ 'total_chunks': total_chunks,
183
+ 'avg_chunk_size': avg_chunk_size,
184
+ 'is_academic': is_academic,
185
+ 'is_single_file': file_count == 1,
186
+ 'is_multi_file': file_count > 1
187
+ }
188
+
189
+ def select_best_method(
190
+ self,
191
+ query_features: Dict,
192
+ file_features: Dict,
193
+ enable_advanced: bool = True
194
+ ) -> RAGMethod:
195
+ """
196
+ 根据特征选择最佳 RAG 方法
197
+
198
+ Args:
199
+ query_features: 查询特征(来自 analyze_query)
200
+ file_features: 文件特征(来自 analyze_files)
201
+ enable_advanced: 是否启用高级方法(如果 False,只使用基础方法)
202
+
203
+ Returns:
204
+ 选择的 RAG 方法
205
+ """
206
+ if not enable_advanced:
207
+ return RAGMethod.BASIC
208
+
209
+ complexity = query_features['complexity']
210
+ query_type = query_features['type']
211
+ has_multiple_questions = query_features['has_multiple_questions']
212
+ is_comparative = query_features['is_comparative']
213
+ has_technical_terms = query_features['has_technical_terms']
214
+ needs_explanation = query_features['needs_explanation']
215
+ file_count = file_features['file_count']
216
+ is_multi_file = file_features['is_multi_file']
217
+
218
+ # 决策树
219
+
220
+ # 1. 非常复杂的查询 + 多文件 → Triple Hybrid(最强)
221
+ if complexity == QueryComplexity.VERY_COMPLEX and is_multi_file:
222
+ return RAGMethod.TRIPLE_HYBRID
223
+
224
+ # 2. 复杂查询 + 多问题 → SubQuery 或 Hybrid Subquery+HyDE
225
+ if complexity in [QueryComplexity.COMPLEX, QueryComplexity.VERY_COMPLEX]:
226
+ if has_multiple_questions or query_type == QueryType.MULTI_ASPECT:
227
+ if is_multi_file:
228
+ return RAGMethod.HYBRID_SUBQUERY_HYDE
229
+ else:
230
+ return RAGMethod.SUBQUERY
231
+
232
+ # 3. 原理性查询 → Step-back(需要背景知识)
233
+ if query_type == QueryType.PRINCIPLE:
234
+ if complexity in [QueryComplexity.MODERATE, QueryComplexity.COMPLEX]:
235
+ return RAGMethod.STEP_BACK
236
+
237
+ # 4. 专业术语查询 → HyDE(生成假设文档)
238
+ if has_technical_terms and complexity in [QueryComplexity.MODERATE, QueryComplexity.COMPLEX]:
239
+ return RAGMethod.HYDE
240
+
241
+ # 5. 比较性查询 + 多文件 → SubQuery(需要分别检索)
242
+ if is_comparative and is_multi_file:
243
+ return RAGMethod.SUBQUERY
244
+
245
+ # 6. 中等复杂度 + 多文件 → Hybrid Subquery+HyDE
246
+ if complexity == QueryComplexity.MODERATE and is_multi_file:
247
+ return RAGMethod.HYBRID_SUBQUERY_HYDE
248
+
249
+ # 7. 简单查询 → 基础 RAG 或 HyDE
250
+ if complexity == QueryComplexity.SIMPLE:
251
+ if has_technical_terms:
252
+ return RAGMethod.HYDE
253
+ else:
254
+ return RAGMethod.BASIC
255
+
256
+ # 8. 需要解释的查询 → Step-back(提供背景知识)
257
+ if needs_explanation and complexity == QueryComplexity.MODERATE:
258
+ return RAGMethod.STEP_BACK
259
+
260
+ # 9. 默认:中等复杂度使用 Step-back
261
+ if complexity == QueryComplexity.MODERATE:
262
+ return RAGMethod.STEP_BACK
263
+
264
+ # 10. 默认:复杂查询使用 SubQuery
265
+ return RAGMethod.SUBQUERY
266
+
267
+ def get_method_reason(self, method: RAGMethod, query_features: Dict, file_features: Dict) -> str:
268
+ """
269
+ 获取选择该方法的理由
270
+
271
+ Args:
272
+ method: 选择的 RAG 方法
273
+ query_features: 查询特征
274
+ file_features: 文件特征
275
+
276
+ Returns:
277
+ 选择理由的字符串
278
+ """
279
+ complexity = query_features['complexity'].value
280
+ query_type = query_features['type'].value
281
+ file_count = file_features['file_count']
282
+
283
+ reasons = {
284
+ RAGMethod.BASIC: f"简单查询({complexity}),使用基础 RAG 方法即可",
285
+ RAGMethod.SUBQUERY: f"查询包含多个方面({query_features['question_count']}个问题,{complexity}),使用子查询分解以全面检索",
286
+ RAGMethod.HYDE: f"查询包含专业术语({complexity}),使用假设文档嵌入以改善语义检索",
287
+ RAGMethod.STEP_BACK: f"原理性查询({query_type},{complexity}),使用后退推理获取背景知识和原理",
288
+ RAGMethod.HYBRID_SUBQUERY_HYDE: f"复杂查询({complexity})+ {file_count}个文件,使用混合子查询+HyDE方法以全面检索",
289
+ RAGMethod.TRIPLE_HYBRID: f"非常复杂的查询({complexity})+ {file_count}个文件,使用三重混合方法(SubQuery+HyDE+Step-back)以获得最佳效果"
290
+ }
291
+ return reasons.get(method, f"使用 {method.value} 方法")
292
+
deep_agent_rag/rag/llm_adapter.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM 适配器:将 LangChain ChatModel 包装成 OllamaLLM 接口
3
+ 用于兼容 Learn_RAG 项目中的进阶 RAG 方法
4
+ """
5
+ from typing import Optional
6
+ from langchain_core.messages import HumanMessage
7
+ from langchain_core.language_models.chat_models import BaseChatModel
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class LangChainLLMAdapter:
14
+ """
15
+ 将 LangChain ChatModel 适配为 OllamaLLM 接口
16
+
17
+ 这个适配器允许 Learn_RAG 项目中的进阶 RAG 方法(需要 OllamaLLM)
18
+ 使用 Deep_Agentic_AI_Tool 的统一 LLM 系统(Groq -> Ollama -> MLX)
19
+ """
20
+
21
+ def __init__(self, langchain_llm: BaseChatModel):
22
+ """
23
+ 初始化适配器
24
+
25
+ Args:
26
+ langchain_llm: LangChain ChatModel 实例(来自 get_llm())
27
+ """
28
+ self.llm = langchain_llm
29
+ self.model_name = self._detect_model_name()
30
+ self.base_url = "http://localhost:11434" # 默认值,实际不使用
31
+ self.timeout = 120 # 默认值,实际不使用
32
+
33
+ logger.info(f"✅ LLM 适配器初始化完成 (模型类型: {self.model_name})")
34
+
35
+ def _detect_model_name(self) -> str:
36
+ """
37
+ 检测 LLM 类型和模型名称
38
+
39
+ Returns:
40
+ 模型名称字符串
41
+ """
42
+ llm_type = type(self.llm).__name__
43
+
44
+ # 检测 Groq
45
+ if "Groq" in llm_type or "ChatGroq" in llm_type:
46
+ model_name = getattr(self.llm, 'model_name', 'groq-unknown')
47
+ return f"groq:{model_name}"
48
+
49
+ # 检测 Ollama
50
+ if "Ollama" in llm_type or "ChatOllama" in llm_type:
51
+ model_name = getattr(self.llm, 'model', 'ollama-unknown')
52
+ return f"ollama:{model_name}"
53
+
54
+ # 检测 MLX
55
+ if "MLX" in llm_type or "MLXChatModel" in llm_type:
56
+ return "mlx:qwen2.5"
57
+
58
+ # 默认
59
+ return f"langchain:{llm_type}"
60
+
61
+ def _check_ollama_connection(self) -> bool:
62
+ """
63
+ 检查 Ollama 服务是否可用(兼容性方法,实际不使用)
64
+
65
+ Returns:
66
+ 总是返回 True(因为我们使用的是统一的 LLM 系统)
67
+ """
68
+ return True
69
+
70
+ def _check_model_available(self) -> bool:
71
+ """
72
+ 检查模型是否可用(兼容性方法,实际不使用)
73
+
74
+ Returns:
75
+ 总是返回 True(因为我们使用的是统一的 LLM 系统)
76
+ """
77
+ return True
78
+
79
+ def generate(
80
+ self,
81
+ prompt: str,
82
+ temperature: float = 0.7,
83
+ max_tokens: Optional[int] = None,
84
+ stream: bool = False
85
+ ) -> str:
86
+ """
87
+ 生成回答(兼容 OllamaLLM.generate 接口)
88
+
89
+ Args:
90
+ prompt: 输入 prompt
91
+ temperature: 温度参数(0.0-1.0),控制随机性
92
+ max_tokens: 最大生成 token 数(None 表示使用模型预设)
93
+ stream: 是否使用流式输出(当前不支持,总是返回完整结果)
94
+
95
+ Returns:
96
+ 生成的回答字符串
97
+ """
98
+ try:
99
+ # 将 prompt 转换为 LangChain 消息格式
100
+ messages = [HumanMessage(content=prompt)]
101
+
102
+ # 准备调用参数
103
+ invoke_kwargs = {}
104
+
105
+ # 如果 LLM 支持 temperature 参数
106
+ if hasattr(self.llm, 'temperature'):
107
+ # 临时设置 temperature(如果支持)
108
+ original_temp = getattr(self.llm, 'temperature', None)
109
+ try:
110
+ self.llm.temperature = temperature
111
+ except:
112
+ pass # 如果不支持设置,忽略
113
+
114
+ # 如果 LLM 支持 max_tokens 参数
115
+ if max_tokens and hasattr(self.llm, 'max_tokens'):
116
+ original_max_tokens = getattr(self.llm, 'max_tokens', None)
117
+ try:
118
+ self.llm.max_tokens = max_tokens
119
+ except:
120
+ pass # 如果不支持设置,忽略
121
+
122
+ # 调用 LangChain LLM
123
+ response = self.llm.invoke(messages, **invoke_kwargs)
124
+
125
+ # 恢复原始参数(如果之前修改过)
126
+ if hasattr(self.llm, 'temperature') and 'original_temp' in locals():
127
+ try:
128
+ self.llm.temperature = original_temp
129
+ except:
130
+ pass
131
+
132
+ if hasattr(self.llm, 'max_tokens') and 'original_max_tokens' in locals():
133
+ try:
134
+ self.llm.max_tokens = original_max_tokens
135
+ except:
136
+ pass
137
+
138
+ # 提取回答内容
139
+ if hasattr(response, 'content'):
140
+ answer = response.content
141
+ elif isinstance(response, str):
142
+ answer = response
143
+ else:
144
+ answer = str(response)
145
+
146
+ return answer.strip()
147
+
148
+ except Exception as e:
149
+ logger.error(f"⚠️ LLM 生成回答时出错: {e}")
150
+ raise RuntimeError(f"LLM 生成失败: {e}") from e
151
+
deep_agent_rag/rag/private_file_rag.py ADDED
The diff for this file is too large to render. See raw diff
 
deep_agent_rag/ui/calendar_interface.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # deep_agent_rag/ui/calendar_interface.py
2
+
3
+ import gradio as gr
4
+ from datetime import datetime, timedelta
5
+ import re
6
+ import json
7
+ import time
8
+
9
+ from ..agents.calendar_agent import generate_calendar_draft, create_calendar_draft
10
+ # Assuming is_using_local_llm might be used for warnings/status, similar to email_interface
11
+ # from ..utils.llm_utils import is_using_local_llm
12
+
13
+ # Agent log path for debugging (if needed)
14
+ log_path = "/Users/matthuang/Desktop/Deep_Agentic_AI_Tool/.cursor/debug.log"
15
+
16
+ def _create_calendar_interface():
17
+ """創建 Calendar Tool 界面"""
18
+ gr.Markdown(
19
+ """
20
+ ### 📅 智能行事曆管理助手
21
+
22
+ 使用 AI 根據您的完整提示自動生成行事曆事件草稿,您可以在創建前檢查和修改。
23
+
24
+ **使用方式:**
25
+ 1. **快速選擇**:點擊下方常見事件按鈕,自動生成草稿
26
+ 2. **自定義輸入**:在下方輸入完整的事件提示,包含:事件、日期、時間、地點、參與者
27
+ 3. 查看 AI 反思評估結果和改進建議(如有)
28
+ 4. 如果有缺失的資訊(如時間),系統會顯示下拉選單讓您選擇
29
+ 5. 檢查並修改生成的事件內容
30
+ 6. 確認無誤後點擊「創建事件」按鈕
31
+
32
+ **✨ 新功能:AI 迭代反思評估 + Google Maps 地點驗證**
33
+ - 系統會自動進行多輪反思評估(最多 3 輪)
34
+ - 自動驗證並標準化地址,計算交通時間
35
+ - 每輪評估後,如果有改進建議,會自動生成改進版本
36
+ - 改進後的版本會再次評估,直到 AI 認為滿意為止
37
+ """
38
+ )
39
+
40
+ # 快速選擇按鈕區域
41
+ gr.Markdown("### 🚀 快速選擇常見事件")
42
+ with gr.Row():
43
+ quick_meeting_btn = gr.Button("📋 團隊會議", variant="secondary", scale=1)
44
+ quick_client_btn = gr.Button("🤝 客戶拜訪", variant="secondary", scale=1)
45
+ quick_lunch_btn = gr.Button("🍽️ 午餐會議", variant="secondary", scale=1)
46
+ quick_oneonone_btn = gr.Button("💬 一對一會議", variant="secondary", scale=1)
47
+ with gr.Row():
48
+ quick_project_btn = gr.Button("📊 項目討論", variant="secondary", scale=1)
49
+ quick_training_btn = gr.Button("🎓 培訓/學習", variant="secondary", scale=1)
50
+ quick_social_btn = gr.Button("🎉 社交活動", variant="secondary", scale=1)
51
+ quick_custom_btn = gr.Button("✏️ 自定義輸入", variant="secondary", scale=1)
52
+
53
+ with gr.Row():
54
+ with gr.Column(scale=1):
55
+ # 單一 prompt 輸入
56
+ calendar_prompt_input = gr.Textbox(
57
+ label="📝 事件提示(包含事件、日期、時間、地點、參與者)",
58
+ placeholder="例如:明天下午2點團隊會議,討論項目進度,地點在會議室A,參與者包括john@example.com和mary@example.com",
59
+ lines=5,
60
+ value=""
61
+ )
62
+
63
+ # 按鈕
64
+ with gr.Row():
65
+ generate_draft_btn = gr.Button("📝 生成事件草稿", variant="primary", scale=1)
66
+ clear_calendar_btn = gr.Button("🗑️ 清除", variant="secondary", scale=1)
67
+
68
+ # 狀態顯示
69
+ calendar_status_display = gr.Textbox(
70
+ label="📊 狀態",
71
+ value="等待操作...",
72
+ interactive=False,
73
+ lines=2
74
+ )
75
+
76
+ # 反思結果顯示
77
+ calendar_reflection_display = gr.Textbox(
78
+ label="🔍 AI 反思評估",
79
+ value="等待生成事件...",
80
+ interactive=False,
81
+ lines=8,
82
+ visible=True
83
+ )
84
+
85
+ # 缺失資訊的補充區域(動態顯示)
86
+ missing_info_group = gr.Group(visible=False)
87
+ with missing_info_group:
88
+ gr.Markdown("**⚠️ 請補充以下缺失的資訊:**")
89
+
90
+ # 日期選擇(如果缺失)
91
+ missing_date_display = gr.Dropdown(
92
+ label="📆 選擇日期",
93
+ choices=[],
94
+ visible=False,
95
+ interactive=True
96
+ )
97
+
98
+ # 時間選擇(如果缺失)
99
+ missing_time_display = gr.Dropdown(
100
+ label="🕐 選擇時間",
101
+ choices=[],
102
+ visible=False,
103
+ interactive=True
104
+ )
105
+
106
+ fill_missing_btn = gr.Button("✅ 確認補充資訊", variant="primary", visible=False)
107
+
108
+ # 隱藏狀態變數,用於存儲 event_dict
109
+ event_dict_storage = gr.State(value={})
110
+
111
+ with gr.Column(scale=1):
112
+ # 事件詳情顯示和編輯區域
113
+ event_summary_display = gr.Textbox(
114
+ label="📌 事件標題",
115
+ placeholder="事件標題將在這裡顯示",
116
+ lines=1,
117
+ interactive=True
118
+ )
119
+
120
+ event_start_display = gr.Textbox(
121
+ label="🕐 開始時間",
122
+ placeholder="開始時間將在這裡顯示(格式: YYYY-MM-DDTHH:MM:SS+08:00)",
123
+ lines=1,
124
+ interactive=True
125
+ )
126
+
127
+ event_end_display = gr.Textbox(
128
+ label="🕐 結束時間",
129
+ placeholder="結束時間將在這裡顯示(格式: YYYY-MM-DDTHH:MM:SS+08:00)",
130
+ lines=1,
131
+ interactive=True
132
+ )
133
+
134
+ event_description_display = gr.Textbox(
135
+ label="📄 事件描述(可編輯)",
136
+ placeholder="事件描述將在這裡顯示,您可以編輯",
137
+ lines=6,
138
+ interactive=True
139
+ )
140
+
141
+ event_location_display = gr.Textbox(
142
+ label="📍 地點(可編輯,已自動驗證並標準化)",
143
+ placeholder="事件地點將在這裡顯示,您可以編輯",
144
+ lines=2,
145
+ interactive=True
146
+ )
147
+
148
+ event_attendees_display = gr.Textbox(
149
+ label="👥 參與者郵箱(可編輯,多個用逗號分隔)",
150
+ placeholder="參與者郵箱將在這裡顯示,您可以編輯",
151
+ lines=1,
152
+ interactive=True
153
+ )
154
+
155
+ # 創建按鈕
156
+ create_event_btn = gr.Button("✅ 創建事件", variant="primary", scale=1)
157
+
158
+ # 操作結果顯示
159
+ calendar_result_display = gr.Textbox(
160
+ label="📊 操作結果",
161
+ lines=8,
162
+ interactive=False
163
+ )
164
+
165
+ # 生成時間選項(每30分鐘一個選項)
166
+ def generate_time_options():
167
+ """生成時間選項列表"""
168
+ times = []
169
+ for hour in range(24):
170
+ for minute in [0, 30]:
171
+ time_str = f"{hour:02d}:{minute:02d}"
172
+ times.append(time_str)
173
+ return times
174
+
175
+ # 生成日期選項(今天、明天、後天,以及未來7天)
176
+ def generate_date_options():
177
+ """生成日期選項列表"""
178
+ dates = []
179
+ today = datetime.now()
180
+ date_names = ["今天", "明天", "後天"]
181
+
182
+ for i in range(3):
183
+ date_obj = today + timedelta(days=i)
184
+ date_str = date_obj.strftime('%Y-%m-%d')
185
+ dates.append(f"{date_names[i]} ({date_str})")
186
+
187
+ for i in range(3, 7):
188
+ date_obj = today + timedelta(days=i)
189
+ date_str = date_obj.strftime('%Y-%m-%d')
190
+ dates.append(date_str)
191
+
192
+ return dates
193
+
194
+ # 快速選擇事件模板生成函數
195
+ def generate_quick_prompt(event_type: str) -> str:
196
+ """根據事件類型生成預設提示"""
197
+ from datetime import datetime, timedelta
198
+
199
+ # 獲取明天的日期
200
+ tomorrow = datetime.now() + timedelta(days=1)
201
+ tomorrow_str = tomorrow.strftime("%Y-%m-%d")
202
+
203
+ templates = {
204
+ "meeting": f"明天下午2點團隊會議,討論項目進度和下週計劃,地點在會議室,參與者包括團隊成員",
205
+ "client": f"明天上午10點客戶拜訪,討論合作方案和需求,地點在客戶公司或會議室",
206
+ "lunch": f"明天中午12點午餐會議,與合作夥伴討論業務合作,地點在附近的餐廳",
207
+ "oneonone": f"明天下午3點一對一會議,討論工作進展和職業發展,地點在會議室或咖啡廳",
208
+ "project": f"明天上午9點項目討論會議,審查項目進度和解決問題,地點在項目室,參與者包括項目團隊",
209
+ "training": f"明天下午2點培訓課程,學習新技能和最佳實踐,地點在培訓室或線上",
210
+ "social": f"明天晚上6點團隊聚餐,慶祝項目完成,地點在餐廳,參與者包括團隊成員",
211
+ "custom": "" # 自定義,返回空讓用戶輸入
212
+ }
213
+
214
+ return templates.get(event_type, "")
215
+
216
+ # 快速選擇按鈕處理函數(自動生成草稿)
217
+ def quick_select_and_generate(event_type: str):
218
+ """快速選擇事件類型並自動生成草稿"""
219
+ prompt = generate_quick_prompt(event_type)
220
+ if not prompt:
221
+ # 如果是自定義,只返回空提示,不自動生成
222
+ return (
223
+ prompt, # calendar_prompt_input
224
+ "請在下方輸入框中輸入事件提示,然後點擊「生成事件草稿」", # calendar_status_display
225
+ "等待輸入...", # calendar_reflection_display
226
+ gr.update(visible=False), # missing_info_group
227
+ gr.update(visible=False, choices=[]), # missing_date_display
228
+ gr.update(visible=False, choices=[]), # missing_time_display
229
+ gr.update(visible=False), # fill_missing_btn
230
+ "", "", "", "", "", "", # event fields
231
+ {},
232
+ "" # calendar_result_display
233
+ )
234
+
235
+ # 自動生成草稿(調用 generate_draft 並返回所有輸出)
236
+ draft_result = generate_draft(prompt)
237
+ # generate_draft 返回的格式是:(status, reflection_display, missing_info_group, ...)
238
+ # 但我們需要返回 (prompt, status, reflection_display, ...)
239
+ # draft_result 是一個元組,我們需要將 prompt 添加到開頭
240
+ return (prompt,) + draft_result
241
+
242
+ def quick_select_meeting():
243
+ """快速選擇:團隊會議"""
244
+ return quick_select_and_generate("meeting")
245
+
246
+ def quick_select_client():
247
+ """快速選擇:客戶拜訪"""
248
+ return quick_select_and_generate("client")
249
+
250
+ def quick_select_lunch():
251
+ """快速選擇:午餐會議"""
252
+ return quick_select_and_generate("lunch")
253
+
254
+ def quick_select_oneonone():
255
+ """快速選擇:一對一會議"""
256
+ return quick_select_and_generate("oneonone")
257
+
258
+ def quick_select_project():
259
+ """快速選擇:項目討論"""
260
+ return quick_select_and_generate("project")
261
+
262
+ def quick_select_training():
263
+ """快速選擇:培訓/學習"""
264
+ return quick_select_and_generate("training")
265
+
266
+ def quick_select_social():
267
+ """快速選擇:社交活動"""
268
+ return quick_select_and_generate("social")
269
+
270
+ def quick_select_custom():
271
+ """快速選擇:自定義輸入(只清空,不自動生成)"""
272
+ return (
273
+ "", # calendar_prompt_input
274
+ "請在下方輸入框中輸入事件提示,然後點擊「生成事件草稿」", # calendar_status_display
275
+ "等待輸入...", # calendar_reflection_display
276
+ gr.update(visible=False), # missing_info_group
277
+ gr.update(visible=False, choices=[]), # missing_date_display
278
+ gr.update(visible=False, choices=[]), # missing_time_display
279
+ gr.update(visible=False), # fill_missing_btn
280
+ "", "", "", "", "", "", # event fields
281
+ {},
282
+ "" # calendar_result_display
283
+ )
284
+
285
+ # 事件處理函數
286
+ def generate_draft(prompt):
287
+ """生成行事曆事件草稿(包含反思功能)"""
288
+ if not prompt or not prompt.strip():
289
+ return (
290
+ "❌ 請輸入事件提示",
291
+ "❌ 請輸入事件提示",
292
+ gr.update(visible=False),
293
+ gr.update(visible=False, choices=[]),
294
+ gr.update(visible=False, choices=[]),
295
+ gr.update(visible=False),
296
+ "", "", "", "", "", "", "",
297
+ "❌ 請輸入事件提示"
298
+ )
299
+
300
+ try:
301
+ status_msg = "🔄 正在生成事件草稿..."
302
+
303
+ # 生成事件草稿(包含反思功能)
304
+ event_dict, status, missing_info, reflection_result, was_improved = generate_calendar_draft(
305
+ prompt.strip(), enable_reflection=True
306
+ )
307
+
308
+ if not event_dict:
309
+ return (
310
+ status,
311
+ gr.update(visible=False),
312
+ gr.update(visible=False, choices=[]),
313
+ gr.update(visible=False, choices=[]),
314
+ gr.update(visible=False),
315
+ "", "", "", "", "", "", "",
316
+ status
317
+ )
318
+
319
+ # 格式化反思結果顯示
320
+ if reflection_result:
321
+ # 計算反思輪數
322
+ reflection_count = reflection_result.count("【第") if "【第" in reflection_result else 0
323
+
324
+ if was_improved:
325
+ if reflection_count > 1:
326
+ reflection_display = (
327
+ f"🔍 **AI 迭代反思評估結果**(共 {reflection_count} 輪)\n\n"
328
+ f"{reflection_result}\n\n"
329
+ f"✨ **已自動應用改進建議,經過 {reflection_count} 輪優化,當前顯示的是最終優化版本**"
330
+ )
331
+ else:
332
+ reflection_display = (
333
+ f"🔍 **AI 反思評估結果**\n\n"
334
+ f"{reflection_result}\n\n"
335
+ f"✨ **已自動應用改進建議,當前顯示的是優化後的版本**"
336
+ )
337
+ else:
338
+ reflection_display = (
339
+ f"🔍 **AI 反思評估結果**\n\n"
340
+ f"{reflection_result}\n\n"
341
+ f"✅ **事件質量良好,無需改進**"
342
+ )
343
+ else:
344
+ reflection_display = "⚠️ 反思功能未返回結果"
345
+
346
+ # 【Google Maps 整合】添加地點建議訊息
347
+ location_suggestion = event_dict.get("location_suggestion", "")
348
+ if location_suggestion:
349
+ # 將地點建議添加到狀態訊息中
350
+ if status:
351
+ status = f"{status}\n\n🗺️ **地點資訊:**\n{location_suggestion}"
352
+ else:
353
+ status = f"🗺️ **地點資訊:**\n{location_suggestion}"
354
+
355
+ # 檢查是否有缺失資訊
356
+ has_missing = bool(missing_info)
357
+
358
+ if has_missing:
359
+ # 顯示缺失資訊區域
360
+ date_visible = missing_info.get("date", False)
361
+ time_visible = missing_info.get("time", False)
362
+
363
+ date_choices = generate_date_options() if date_visible else []
364
+ time_choices = generate_time_options() if time_visible else []
365
+
366
+ return (
367
+ status,
368
+ reflection_display,
369
+ gr.update(visible=True), # 顯示缺失資訊區域
370
+ gr.update(visible=date_visible, choices=date_choices, value=date_choices[0] if date_choices else None),
371
+ gr.update(visible=time_visible, choices=time_choices, value=time_choices[0] if time_choices else None),
372
+ gr.update(visible=True), # 顯示確認按鈕
373
+ event_dict.get("summary", ""),
374
+ event_dict.get("start_datetime", ""),
375
+ event_dict.get("end_datetime", ""),
376
+ event_dict.get("description", ""),
377
+ event_dict.get("location", ""),
378
+ event_dict.get("attendees", ""),
379
+ event_dict, # 傳遞完整的事件字典以便後續使用
380
+ ""
381
+ )
382
+ else:
383
+ # 沒有缺失資訊,直接顯示結果
384
+ return (
385
+ status,
386
+ reflection_display,
387
+ gr.update(visible=False),
388
+ gr.update(visible=False, choices=[]),
389
+ gr.update(visible=False, choices=[]),
390
+ gr.update(visible=False),
391
+ event_dict.get("summary", ""),
392
+ event_dict.get("start_datetime", ""),
393
+ event_dict.get("end_datetime", ""),
394
+ event_dict.get("description", ""),
395
+ event_dict.get("location", ""),
396
+ event_dict.get("attendees", ""),
397
+ event_dict,
398
+ ""
399
+ )
400
+ except Exception as e:
401
+ error_msg = f"❌ 發生錯誤:{str(e)}"
402
+ print(f"Calendar Tool 錯誤:{e}")
403
+ import traceback
404
+ traceback.print_exc()
405
+ return (
406
+ "❌ 發生錯誤",
407
+ f"❌ 發生錯誤:{str(e)}",
408
+ gr.update(visible=False),
409
+ gr.update(visible=False, choices=[]),
410
+ gr.update(visible=False, choices=[]),
411
+ gr.update(visible=False),
412
+ "", "", "", "", "", "", {},
413
+ error_msg
414
+ )
415
+
416
+ def fill_missing_info(event_dict_storage, selected_date, selected_time):
417
+ """填充缺失的資訊"""
418
+ if not event_dict_storage:
419
+ return (
420
+ "❌ 沒有事件資料",
421
+ gr.update(visible=False),
422
+ gr.update(visible=False, choices=[]),
423
+ gr.update(visible=False, choices=[]),
424
+ gr.update(visible=False),
425
+ "", "", "", "", "", "",
426
+ {}
427
+ )
428
+
429
+ # 更新日期和時間
430
+ if selected_date:
431
+ # 從選項中提取日期字串(例如:"明天 (2026-01-25)" -> "2026-01-25")
432
+ if "(" in selected_date:
433
+ date_str = selected_date.split("(")[1].split(")")[0]
434
+ else:
435
+ date_str = selected_date
436
+ else:
437
+ date_str = event_dict_storage.get("date", "今天")
438
+
439
+ if selected_time:
440
+ time_str = selected_time
441
+ else:
442
+ time_str = "09:00" # 預設時間
443
+
444
+ # 重新解析日期和時間
445
+ from ..agents.calendar_agent import parse_datetime # Import here to avoid circular dependency or unnecessary global import
446
+ start_datetime, end_datetime = parse_datetime(date_str, time_str)
447
+
448
+ # 更新事件字典
449
+ event_dict_storage["start_datetime"] = start_datetime
450
+ event_dict_storage["end_datetime"] = end_datetime
451
+
452
+ return (
453
+ "✅ 資訊已補充,請檢查並創建事件",
454
+ gr.update(visible=False), # 隱藏缺失資訊區域
455
+ gr.update(visible=False, choices=[]),
456
+ gr.update(visible=False, choices=[]),
457
+ gr.update(visible=False),
458
+ event_dict_storage.get("summary", ""),
459
+ start_datetime,
460
+ end_datetime,
461
+ event_dict_storage.get("description", ""),
462
+ event_dict_storage.get("location", ""),
463
+ event_dict_storage.get("attendees", ""),
464
+ event_dict_storage
465
+ )
466
+
467
+ def create_event(summary, start_datetime, end_datetime, description, location, attendees):
468
+ """創建行事曆事件"""
469
+ if not summary or not summary.strip():
470
+ return "❌ 請輸入事件標題", "❌ 請輸入事件標題"
471
+
472
+ if not start_datetime or not start_datetime.strip():
473
+ return "❌ 請輸入開始時間", "❌ 請輸入開始時間"
474
+
475
+ if not end_datetime or not end_datetime.strip():
476
+ return "❌ 請輸入結束時間", "❌ 請輸入結束時間"
477
+
478
+ try:
479
+ status_msg = "🔄 正在創建事件..."
480
+
481
+ # 構建事件字典
482
+ event_dict = {
483
+ "summary": summary.strip(),
484
+ "start_datetime": start_datetime.strip(),
485
+ "end_datetime": end_datetime.strip(),
486
+ "description": description.strip() if description else "",
487
+ "location": location.strip() if location else "",
488
+ "attendees": attendees.strip() if attendees else "",
489
+ "timezone": "Asia/Taipei"
490
+ }
491
+
492
+ # 創建事件
493
+ result = create_calendar_draft(event_dict)
494
+
495
+ return "✅ 事件已創建", result
496
+ except Exception as e:
497
+ error_msg = f"❌ 創建事件時發生錯誤:{str(e)}"
498
+ print(f"Calendar Tool 錯誤:{e}")
499
+ import traceback
500
+ traceback.print_exc()
501
+ return "❌ 發生錯誤", error_msg
502
+
503
+ def clear_calendar():
504
+ """清除行事曆相關輸入和輸出"""
505
+ return (
506
+ "", # prompt
507
+ "等待操作...", # status
508
+ "等待生成事件...", # reflection_display
509
+ gr.update(visible=False), # missing_info_group
510
+ gr.update(visible=False, choices=[]), # missing_date
511
+ gr.update(visible=False, choices=[]), # missing_time
512
+ gr.update(visible=False), # fill_missing_btn
513
+ "", "", "", "", "", "", # event fields
514
+ {},
515
+ "" # result
516
+ )
517
+
518
+ # 綁定事件
519
+ generate_draft_btn.click(
520
+ fn=generate_draft,
521
+ inputs=[calendar_prompt_input],
522
+ outputs=[
523
+ calendar_status_display,
524
+ calendar_reflection_display,
525
+ missing_info_group,
526
+ missing_date_display,
527
+ missing_time_display,
528
+ fill_missing_btn,
529
+ event_summary_display,
530
+ event_start_display,
531
+ event_end_display,
532
+ event_description_display,
533
+ event_location_display,
534
+ event_attendees_display,
535
+ event_dict_storage,
536
+ calendar_result_display
537
+ ]
538
+ )
539
+
540
+ # 綁定快速選擇按鈕(自動填充提示並生成草稿)
541
+ quick_outputs = [
542
+ calendar_prompt_input, # 更新提示輸入框
543
+ calendar_status_display,
544
+ calendar_reflection_display,
545
+ missing_info_group,
546
+ missing_date_display,
547
+ missing_time_display,
548
+ fill_missing_btn,
549
+ event_summary_display,
550
+ event_start_display,
551
+ event_end_display,
552
+ event_description_display,
553
+ event_location_display,
554
+ event_attendees_display,
555
+ event_dict_storage,
556
+ calendar_result_display
557
+ ]
558
+
559
+ quick_meeting_btn.click(fn=quick_select_meeting, outputs=quick_outputs)
560
+ quick_client_btn.click(fn=quick_select_client, outputs=quick_outputs)
561
+ quick_lunch_btn.click(fn=quick_select_lunch, outputs=quick_outputs)
562
+ quick_oneonone_btn.click(fn=quick_select_oneonone, outputs=quick_outputs)
563
+ quick_project_btn.click(fn=quick_select_project, outputs=quick_outputs)
564
+ quick_training_btn.click(fn=quick_select_training, outputs=quick_outputs)
565
+ quick_social_btn.click(fn=quick_select_social, outputs=quick_outputs)
566
+ quick_custom_btn.click(fn=quick_select_custom, outputs=quick_outputs)
567
+
568
+ fill_missing_btn.click(
569
+ fn=fill_missing_info,
570
+ inputs=[event_dict_storage, missing_date_display, missing_time_display],
571
+ outputs=[
572
+ calendar_status_display,
573
+ missing_info_group,
574
+ missing_date_display,
575
+ missing_time_display,
576
+ fill_missing_btn,
577
+ event_summary_display,
578
+ event_start_display,
579
+ event_end_display,
580
+ event_description_display,
581
+ event_location_display,
582
+ event_attendees_display,
583
+ event_dict_storage
584
+ ]
585
+ )
586
+
587
+ create_event_btn.click(
588
+ fn=create_event,
589
+ inputs=[
590
+ event_summary_display,
591
+ event_start_display,
592
+ event_end_display,
593
+ event_description_display,
594
+ event_location_display,
595
+ event_attendees_display
596
+ ],
597
+ outputs=[calendar_status_display, calendar_result_display]
598
+ )
599
+
600
+ clear_calendar_btn.click(
601
+ fn=clear_calendar,
602
+ outputs=[
603
+ calendar_prompt_input,
604
+ calendar_status_display,
605
+ calendar_reflection_display,
606
+ missing_info_group,
607
+ missing_date_display,
608
+ missing_time_display,
609
+ fill_missing_btn,
610
+ event_summary_display,
611
+ event_start_display,
612
+ event_end_display,
613
+ event_description_display,
614
+ event_location_display,
615
+ event_attendees_display,
616
+ event_dict_storage,
617
+ calendar_result_display
618
+ ]
619
+ )
620
+
621
+ # 示例
622
+ gr.Examples(
623
+ examples=[
624
+ "明天下午2點團隊會議,討論項目進度,地點在會議室A,參與者包括john@example.com",
625
+ "2026-01-25 上午9點產品發布會,介紹新功能和改進,地點在總部大樓",
626
+ "後天下午3點客戶會議,討論合作細節,參與者包括客戶代表",
627
+ "下週一上午10點技術分享會,分享最新的 AI 技術,地點在研發中心"
628
+ ],
629
+ inputs=[calendar_prompt_input]
630
+ )
631
+
632
+ # 頁腳說明
633
+ gr.Markdown(
634
+ """
635
+ ---
636
+ **注意事項:**
637
+ 1. 使用 Google Calendar API 管理行事曆事件
638
+ 2. 首次使用需要在專案根目錄放置 `credentials.json`(從 Google Cloud Console 下載的 OAuth2 憑證)
639
+ 3. 首次運行時會自動開啟瀏覽器進行授權,授權後會生成 `token.json` 文件
640
+ 4. 事件內容由 AI 自動生成,請在創建前檢查結果
641
+ 5. 在提示中包含所有資訊:事件、日期、時間、地點、參與者
642
+ 6. 如果缺少日期或時間,系統會顯示下拉選單讓您選擇
643
+ 7. 日期格式支援:YYYY-MM-DD(例如:2026-01-25)或相對日期(今天、明天、後天)
644
+ 8. 時間格式支援:24小時制(14:00)或12小時制(2:00 PM)
645
+
646
+ **設置步驟:**
647
+ - 前往 [Google Cloud Console](https://console.cloud.google.com/) 創建專案
648
+ - 啟用 Google Calendar API
649
+ - 創建 OAuth2 憑證並下載為 `credentials.json`
650
+ - 將 `credentials.json` 放在專案根目錄
651
+ - 確保授予 Calendar API 的完整存取權限
652
+ """
653
+ )
deep_agent_rag/ui/email_interface.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # deep_agent_rag/ui/email_interface.py
2
+
3
+ import gradio as gr
4
+ import re
5
+ import json
6
+ import time
7
+
8
+ from ..agents.email_agent import generate_email_draft, send_email_draft
9
+ from ..config import EMAIL_SENDER
10
+ from ..utils.llm_utils import is_using_local_llm # Assuming this might be used for warnings/status
11
+
12
+ # Agent log path for debugging (if needed)
13
+ log_path = "/Users/matthuang/Desktop/Deep_Agentic_AI_Tool/.cursor/debug.log"
14
+
15
+ def _create_email_interface():
16
+ """創建 Email Tool 界面"""
17
+ gr.Markdown(
18
+ f"""
19
+ ### 📧 智能郵件助手
20
+
21
+ 使用 AI 根據您的關鍵提示自動生成專業郵件草稿,您可以在發送前檢查和修改。
22
+
23
+ **寄件者:** {EMAIL_SENDER}
24
+
25
+ **使用方式:**
26
+ 1. 在下方輸入郵件提示(例如:"寫一封感謝信"、"邀請參加會議"等)
27
+ 2. 輸入收件人 Gmail 郵箱地址(僅支援 @gmail.com 或 @googlemail.com)
28
+ 3. 點擊「生成郵件草稿」按鈕
29
+ 4. 查看 AI 反思評估結果和改進建議(如有)
30
+ 5. 檢查並修改生成的郵件內容(特別是簽名部分)
31
+ 6. 確認無誤後點擊「發送郵件」按鈕
32
+
33
+ **✨ 新功能:AI 迭代反思評估**
34
+ - 系統會自動進行多輪反思評估(最多 3 輪)
35
+ - 每輪評估後,如果有改進建議,會自動生成改進版本
36
+ - 改進後的版本會再次評估,直到 AI 認為滿意為止
37
+ - 您可以看到完整的反思過程和每輪的改進建議
38
+
39
+ **注意:此工具僅支援 Gmail 郵箱,收件人必須使用 Gmail 郵箱地址。**
40
+ """
41
+ )
42
+
43
+ with gr.Row():
44
+ with gr.Column(scale=1):
45
+ # 郵件提示輸入
46
+ email_prompt_input = gr.Textbox(
47
+ label="📝 郵件提示",
48
+ placeholder="例如:寫一封感謝信,感謝對方在項目中的幫助",
49
+ lines=5,
50
+ value="寫一封專業的郵件,介紹我們的 AI 產品"
51
+ )
52
+
53
+ # 收件人輸入
54
+ recipient_input = gr.Textbox(
55
+ label="📮 收件人郵箱(僅支援 Gmail)",
56
+ placeholder="recipient@gmail.com",
57
+ lines=1
58
+ )
59
+
60
+ # 按鈕
61
+ with gr.Row():
62
+ generate_draft_btn = gr.Button("📝 生成郵件草稿", variant="primary", scale=1)
63
+ clear_email_btn = gr.Button("🗑️ 清除", variant="secondary", scale=1)
64
+
65
+ # 狀態顯示
66
+ email_status_display = gr.Textbox(
67
+ label="📊 狀態",
68
+ value="等待操作...",
69
+ interactive=False,
70
+ lines=2
71
+ )
72
+
73
+ # 反思結果顯示
74
+ email_reflection_display = gr.Textbox(
75
+ label="🔍 AI 反思評估",
76
+ value="等待生成郵件...",
77
+ interactive=False,
78
+ lines=8,
79
+ visible=True
80
+ )
81
+
82
+ with gr.Column(scale=1):
83
+ # 郵件主題(可編輯)
84
+ email_subject_input = gr.Textbox(
85
+ label="📌 郵件主題",
86
+ placeholder="郵件主題將在這裡顯示,您可以編輯",
87
+ lines=1,
88
+ interactive=True
89
+ )
90
+
91
+ # 郵件正文(可編輯)
92
+ email_body_input = gr.Textbox(
93
+ label="📄 郵件正文(可編輯)",
94
+ placeholder="郵件內容將在這裡顯示,您可以編輯",
95
+ lines=15,
96
+ interactive=True
97
+ )
98
+
99
+ # 發送按鈕
100
+ send_draft_btn = gr.Button("📧 發送郵件", variant="primary", scale=1)
101
+
102
+ # 發送結果顯示
103
+ email_result_display = gr.Textbox(
104
+ label="📊 發送結果",
105
+ lines=5,
106
+ interactive=False
107
+ )
108
+
109
+ # 事件處理函數
110
+ def generate_draft(prompt, recipient):
111
+ """生成郵件草稿(包含反思功能)"""
112
+ if not prompt or not prompt.strip():
113
+ return "❌ 請輸入郵件提示", "", "", "❌ 請輸入郵件提示", "❌ 請輸入郵件提示"
114
+
115
+ if not recipient or not recipient.strip():
116
+ return "❌ 請輸入收件人郵箱", "", "", "❌ 請輸入收件人郵箱", "❌ 請輸入收件人郵箱"
117
+
118
+ # 驗證郵箱格式和 Gmail 限制
119
+ if "@" not in recipient or "." not in recipient.split("@")[1]:
120
+ return "❌ 郵箱格式不正確", "", "", "❌ 郵箱格式不正確,請輸入有效的郵箱地址", "❌ 郵箱格式不正確,請輸入有效的郵箱地址"
121
+
122
+ # 驗證是否為 Gmail 郵箱
123
+ recipient_lower = recipient.strip().lower()
124
+ if not (recipient_lower.endswith("@gmail.com") or recipient_lower.endswith("@googlemail.com")):
125
+ return "❌ 僅支援 Gmail 郵箱", "", "", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址"
126
+
127
+ try:
128
+ status_msg = "🔄 正在生成郵件草稿..."
129
+ reflection_msg = "🔄 正在生成郵件草稿..."
130
+
131
+ # 生成郵件草稿(包含反思功能,會自動改進)
132
+ subject, body, status, reflection_result, was_improved = generate_email_draft(
133
+ prompt, recipient.strip(), enable_reflection=True
134
+ )
135
+
136
+ if subject and body:
137
+ # 格式化反思結果顯示
138
+ if reflection_result:
139
+ # 計算反思輪數
140
+ reflection_count = reflection_result.count("【第") if "【第" in reflection_result else 0
141
+
142
+ if was_improved:
143
+ if reflection_count > 1:
144
+ reflection_display = (
145
+ f"🔍 **AI 迭代反思評估結果**(共 {reflection_count} 輪)\n\n"
146
+ f"{reflection_result}\n\n"
147
+ f"✨ **已自動應用改進建議,經過 {reflection_count} 輪優化,當前顯示的是最終優化版本**"
148
+ )
149
+ else:
150
+ reflection_display = (
151
+ f"🔍 **AI 反思評估結果**\n\n"
152
+ f"{reflection_result}\n\n"
153
+ f"✨ **已自動應用改進建議,當前顯示的是優化後的版本**"
154
+ )
155
+ else:
156
+ reflection_display = (
157
+ f"🔍 **AI 反思評估結果**\n\n"
158
+ f"{reflection_result}\n\n"
159
+ f"✅ **郵件質量良好,無需改進**"
160
+ )
161
+ else:
162
+ reflection_display = "⚠️ 反思功能未返回結果"
163
+
164
+ return status, subject, body, "", reflection_display
165
+ else:
166
+ return status, "", "", status, "❌ 生成失敗,無法進行反思評估"
167
+ except Exception as e:
168
+ error_msg = f"❌ 發生錯誤:{str(e)}"
169
+ print(f"Email Tool 錯誤:{e}")
170
+ import traceback
171
+ traceback.print_exc()
172
+ return "❌ 發生錯誤", "", "", error_msg, f"❌ 發生錯誤:{str(e)}"
173
+
174
+ def send_draft(recipient, subject, body):
175
+ """發送已編輯的郵件草稿"""
176
+ if not recipient or not recipient.strip():
177
+ return "❌ 請輸入收件人郵箱", "❌ 請輸入收件人郵箱"
178
+
179
+ if not subject or not subject.strip():
180
+ return "❌ 請輸入郵件主題", "❌ 請輸入郵件主題"
181
+
182
+ if not body or not body.strip():
183
+ return "❌ 請輸入郵件內容", "❌ 請輸入郵件內容"
184
+
185
+ # 驗證郵箱格式和 Gmail 限制
186
+ if "@" not in recipient or "." not in recipient.split("@")[1]:
187
+ return "❌ 郵箱格式不正確", "❌ 郵箱格式不正確,請輸入有效的郵箱地址"
188
+
189
+ # 驗證是否為 Gmail 郵箱
190
+ recipient_lower = recipient.strip().lower()
191
+ if not (recipient_lower.endswith("@gmail.com") or recipient_lower.endswith("@googlemail.com")):
192
+ return "❌ 僅支援 Gmail 郵箱", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址"
193
+
194
+ try:
195
+ status_msg = "🔄 正在發送郵件..."
196
+
197
+ # 發送郵件
198
+ result = send_email_draft(recipient.strip(), subject.strip(), body.strip())
199
+
200
+ return "✅ 郵件已發送", result
201
+ except Exception as e:
202
+ error_msg = f"❌ 發送郵件時發生錯誤:{str(e)}"
203
+ print(f"Email Tool 錯誤:{e}")
204
+ import traceback
205
+ traceback.print_exc()
206
+ return "❌ 發生錯誤", error_msg
207
+
208
+ def clear_email():
209
+ """清除郵件相關輸入和輸出"""
210
+ return "", "", "等待操作...", "", "", "等待生成郵件..."
211
+
212
+ # 綁定事件
213
+ generate_draft_btn.click(
214
+ fn=generate_draft,
215
+ inputs=[email_prompt_input, recipient_input],
216
+ outputs=[email_status_display, email_subject_input, email_body_input, email_result_display, email_reflection_display]
217
+ )
218
+
219
+ send_draft_btn.click(
220
+ fn=send_draft,
221
+ inputs=[recipient_input, email_subject_input, email_body_input],
222
+ outputs=[email_status_display, email_result_display]
223
+ )
224
+
225
+ clear_email_btn.click(
226
+ fn=clear_email,
227
+ outputs=[email_prompt_input, recipient_input, email_status_display, email_subject_input, email_body_input, email_result_display, email_reflection_display]
228
+ )
229
+
230
+ # 示例
231
+ gr.Examples(
232
+ examples=[
233
+ ["寫一封感謝信,感謝對方在項目中的幫助和支持", "example@gmail.com"],
234
+ ["邀請參加下週的產品發布會", "colleague@gmail.com"],
235
+ ["詢問項目進度並提供更新", "partner@gmail.com"],
236
+ ["發送會議記錄和後續行動項目", "team@gmail.com"]
237
+ ],
238
+ inputs=[email_prompt_input, recipient_input]
239
+ )
240
+
241
+ # 頁腳說明
242
+ gr.Markdown(
243
+ f"""
244
+ ---
245
+ **注意事項:**
246
+ 1. 使用 Gmail API 發送郵件,避免被歸類為垃圾郵件
247
+ 2. **此工具僅支援 Gmail 郵箱,收件人必須使用 @gmail.com 或 @googlemail.com 結尾的郵箱地址**
248
+ 3. 首次使用需要在專案根目錄放置 `credentials.json`(從 Google Cloud Console 下載的 OAuth2 憑證)
249
+ 4. 首次運行時會自動開啟瀏覽器進行授權,授權後會生成 `token.json` 文件
250
+ 5. 郵件內容由 AI 自動生成,請在發送前檢查結果
251
+ 6. 寄件者固定為:{EMAIL_SENDER}
252
+
253
+ **設置步驟:**
254
+ - 前往 [Google Cloud Console](https://console.cloud.google.com/) 創建專案
255
+ - 啟用 Gmail API
256
+ - 創建 OAuth2 憑證並下載為 `credentials.json`
257
+ - 將 `credentials.json` 放在專案根目錄
258
+ """
259
+ )
deep_agent_rag/ui/gradio_interface.py CHANGED
@@ -5,12 +5,17 @@ Gradio 界面模組
5
  import uuid
6
  import re
7
  import time
 
 
8
  from typing import Iterator, Tuple
9
  import gradio as gr
10
  from langchain_core.messages import HumanMessage
11
 
12
  # graph 和 rag_retriever 將從外部傳入,不在這裡導入
13
  from ..utils.llm_utils import get_llm_type, is_using_local_llm
 
 
 
14
 
15
 
16
  def run_research_agent(query: str, graph, thread_id: str = None) -> Iterator[Tuple[str, str, str, str, str]]:
@@ -204,9 +209,9 @@ def create_gradio_interface(graph):
204
  gr.Markdown(
205
  """
206
  <div class="header">
207
- <h1>🚀 Deep Research Agent with RAG (Local MLX)</h1>
208
  <p><strong>功能特色:</strong></p>
209
- <p>📊 股票資訊查詢 | 🌐 網路搜尋 | 📚 PDF 知識庫查詢(Tree of Thoughts 論文)| 📧 智能郵件助手 | 📅 智能行事曆管理</p>
210
  <p><strong>智能規劃:</strong> 系統會根據問題類型自動選擇合適的研究工具</p>
211
  <p><strong>本地模型:</strong> 使用 MLX 本地模型,保護隱私,無需 API 金鑰</p>
212
  </div>
@@ -227,6 +232,10 @@ def create_gradio_interface(graph):
227
  # Tab 3: Calendar Tool
228
  with gr.Tab("📅 Calendar Tool"):
229
  _create_calendar_interface()
 
 
 
 
230
 
231
  return demo
232
 
@@ -363,895 +372,8 @@ def _create_research_interface(graph):
363
  )
364
 
365
 
366
- def _create_email_interface():
367
- """創建 Email Tool 界面"""
368
- from ..agents.email_agent import generate_email_draft, send_email_draft
369
- from ..config import EMAIL_SENDER
370
-
371
- gr.Markdown(
372
- f"""
373
- ### 📧 智能郵件助手
374
-
375
- 使用 AI 根據您的關鍵提示自動生成專業郵件草稿,您可以在發送前檢查和修改。
376
-
377
- **寄件者:** {EMAIL_SENDER}
378
-
379
- **使用方式:**
380
- 1. 在下方輸入郵件提示(例如:"寫一封感謝信"、"邀請參加會議"等)
381
- 2. 輸入收件人 Gmail 郵箱地址(僅支援 @gmail.com 或 @googlemail.com)
382
- 3. 點擊「生成郵件草稿」按鈕
383
- 4. 查看 AI 反思評估結果和改進建議(如有)
384
- 5. 檢查並修改生成的郵件內容(特別是簽名部分)
385
- 6. 確認無誤後點擊「發送郵件」按鈕
386
-
387
- **✨ 新功能:AI 迭代反思評估**
388
- - 系統會自動進行多輪反思評估(最多 3 輪)
389
- - 每輪評估後,如果有改進建議,會自動生成改進版本
390
- - 改進後的版本會再次評估,直到 AI 認為滿意為止
391
- - 您可以看到完整的反思過程和每輪的改進建議
392
-
393
- **注意:此工具僅支援 Gmail 郵箱,收件人必須使用 Gmail 郵箱地址。**
394
- """
395
- )
396
-
397
- with gr.Row():
398
- with gr.Column(scale=1):
399
- # 郵件提示輸入
400
- email_prompt_input = gr.Textbox(
401
- label="📝 郵件提示",
402
- placeholder="例如:寫一封感謝信,感謝對方在項目中的幫助",
403
- lines=5,
404
- value="寫一封專業的郵件,介紹我們的 AI 產品"
405
- )
406
-
407
- # 收件人輸入
408
- recipient_input = gr.Textbox(
409
- label="📮 收件人郵箱(僅支援 Gmail)",
410
- placeholder="recipient@gmail.com",
411
- lines=1
412
- )
413
-
414
- # 按鈕
415
- with gr.Row():
416
- generate_draft_btn = gr.Button("📝 生成郵件草稿", variant="primary", scale=1)
417
- clear_email_btn = gr.Button("🗑️ 清除", variant="secondary", scale=1)
418
-
419
- # 狀態顯示
420
- email_status_display = gr.Textbox(
421
- label="📊 狀態",
422
- value="等待操作...",
423
- interactive=False,
424
- lines=2
425
- )
426
-
427
- # 反思結果顯示
428
- email_reflection_display = gr.Textbox(
429
- label="🔍 AI 反思評估",
430
- value="等待生成郵件...",
431
- interactive=False,
432
- lines=8,
433
- visible=True
434
- )
435
-
436
- with gr.Column(scale=1):
437
- # 郵件主題(可編輯)
438
- email_subject_input = gr.Textbox(
439
- label="📌 郵件主題",
440
- placeholder="郵件主題將在這裡顯示���您可以編輯",
441
- lines=1,
442
- interactive=True
443
- )
444
-
445
- # 郵件正文(可編輯)
446
- email_body_input = gr.Textbox(
447
- label="📄 郵件正文(可編輯)",
448
- placeholder="郵件內容將在這裡顯示,您可以編輯",
449
- lines=15,
450
- interactive=True
451
- )
452
-
453
- # 發送按鈕
454
- send_draft_btn = gr.Button("📧 發送郵件", variant="primary", scale=1)
455
-
456
- # 發送結果顯示
457
- email_result_display = gr.Textbox(
458
- label="📊 發送結果",
459
- lines=5,
460
- interactive=False
461
- )
462
-
463
- # 事件處理函數
464
- def generate_draft(prompt, recipient):
465
- """生成郵件草稿(包含反思功能)"""
466
- if not prompt or not prompt.strip():
467
- return "❌ 請輸入郵件提示", "", "", "❌ 請輸入郵件提示", "❌ 請輸入郵件提示"
468
-
469
- if not recipient or not recipient.strip():
470
- return "❌ 請輸入收件人郵箱", "", "", "❌ 請輸入收件人郵箱", "❌ 請輸入收件人郵箱"
471
-
472
- # 驗證郵箱格式和 Gmail 限制
473
- if "@" not in recipient or "." not in recipient.split("@")[1]:
474
- return "❌ 郵箱格式不正確", "", "", "❌ 郵箱格式不正確,請輸入有效的郵箱地址", "❌ 郵箱格式不正確,請輸入有效的郵箱地址"
475
-
476
- # 驗證是否為 Gmail 郵箱
477
- recipient_lower = recipient.strip().lower()
478
- if not (recipient_lower.endswith("@gmail.com") or recipient_lower.endswith("@googlemail.com")):
479
- return "❌ 僅支援 Gmail 郵箱", "", "", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址"
480
-
481
- try:
482
- status_msg = "🔄 正在生成郵件草稿..."
483
- reflection_msg = "🔄 正在生成郵件草稿..."
484
-
485
- # 生成郵件草稿(包含反思功能,會自動改進)
486
- subject, body, status, reflection_result, was_improved = generate_email_draft(
487
- prompt, recipient.strip(), enable_reflection=True
488
- )
489
-
490
- if subject and body:
491
- # 格式化反思結果顯示
492
- if reflection_result:
493
- # 計算反思輪數
494
- reflection_count = reflection_result.count("【第") if "【第" in reflection_result else 0
495
-
496
- if was_improved:
497
- if reflection_count > 1:
498
- reflection_display = (
499
- f"🔍 **AI 迭代反思評估結果**(共 {reflection_count} 輪)\n\n"
500
- f"{reflection_result}\n\n"
501
- f"✨ **已自動應用改進建議,經過 {reflection_count} 輪優化,當前顯示的是最終優化版本**"
502
- )
503
- else:
504
- reflection_display = (
505
- f"🔍 **AI 反思評估結果**\n\n"
506
- f"{reflection_result}\n\n"
507
- f"✨ **已自動應用改進建議,當前顯示的是優化後的版本**"
508
- )
509
- else:
510
- reflection_display = (
511
- f"🔍 **AI 反思評估結果**\n\n"
512
- f"{reflection_result}\n\n"
513
- f"✅ **郵件質量良好,無需改進**"
514
- )
515
- else:
516
- reflection_display = "⚠️ 反思功能未返回結果"
517
-
518
- return status, subject, body, "", reflection_display
519
- else:
520
- return status, "", "", status, "❌ 生成失敗,無法進行反思評估"
521
- except Exception as e:
522
- error_msg = f"❌ 發生錯誤:{str(e)}"
523
- print(f"Email Tool 錯誤:{e}")
524
- import traceback
525
- traceback.print_exc()
526
- return "❌ 發生錯誤", "", "", error_msg, f"❌ 發生錯誤:{str(e)}"
527
-
528
- def send_draft(recipient, subject, body):
529
- """發送已編輯的郵件草稿"""
530
- if not recipient or not recipient.strip():
531
- return "❌ 請輸入收件人郵箱", "❌ 請輸入收件人郵箱"
532
-
533
- if not subject or not subject.strip():
534
- return "❌ 請輸入郵件主題", "❌ 請輸入郵件主題"
535
-
536
- if not body or not body.strip():
537
- return "❌ 請輸入郵件內容", "❌ 請輸入郵件內容"
538
-
539
- # 驗證郵箱格式和 Gmail 限制
540
- if "@" not in recipient or "." not in recipient.split("@")[1]:
541
- return "❌ 郵箱格式不正確", "❌ 郵箱格式不正確,請輸入有效的郵箱地址"
542
-
543
- # 驗證是否為 Gmail 郵箱
544
- recipient_lower = recipient.strip().lower()
545
- if not (recipient_lower.endswith("@gmail.com") or recipient_lower.endswith("@googlemail.com")):
546
- return "❌ 僅支援 Gmail 郵箱", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址"
547
-
548
- try:
549
- status_msg = "🔄 正在發送郵件..."
550
-
551
- # 發送郵件
552
- result = send_email_draft(recipient.strip(), subject.strip(), body.strip())
553
-
554
- return "✅ 郵件已發送", result
555
- except Exception as e:
556
- error_msg = f"❌ 發送郵件時發生錯誤:{str(e)}"
557
- print(f"Email Tool 錯誤:{e}")
558
- import traceback
559
- traceback.print_exc()
560
- return "❌ 發生錯誤", error_msg
561
-
562
- def clear_email():
563
- """清除郵件相關輸入和輸出"""
564
- return "", "", "等待操作...", "", "", "等待生成郵件..."
565
-
566
- # 綁定事件
567
- generate_draft_btn.click(
568
- fn=generate_draft,
569
- inputs=[email_prompt_input, recipient_input],
570
- outputs=[email_status_display, email_subject_input, email_body_input, email_result_display, email_reflection_display]
571
- )
572
-
573
- send_draft_btn.click(
574
- fn=send_draft,
575
- inputs=[recipient_input, email_subject_input, email_body_input],
576
- outputs=[email_status_display, email_result_display]
577
- )
578
-
579
- clear_email_btn.click(
580
- fn=clear_email,
581
- outputs=[email_prompt_input, recipient_input, email_status_display, email_subject_input, email_body_input, email_result_display, email_reflection_display]
582
- )
583
-
584
- # 示例
585
- gr.Examples(
586
- examples=[
587
- ["寫一封感謝信,感謝對方在項目中的幫助和支持", "example@gmail.com"],
588
- ["邀請參加下週的產品發布會", "colleague@gmail.com"],
589
- ["詢問項目進度並提供更新", "partner@gmail.com"],
590
- ["發送會議記錄和後續行動項目", "team@gmail.com"]
591
- ],
592
- inputs=[email_prompt_input, recipient_input]
593
- )
594
-
595
- # 頁腳說明
596
- gr.Markdown(
597
- f"""
598
- ---
599
- **注意事項:**
600
- 1. 使用 Gmail API 發送郵件,避免被歸類為垃圾郵件
601
- 2. **此工具僅支援 Gmail 郵箱,收件人必須使用 @gmail.com 或 @googlemail.com 結尾的郵箱地址**
602
- 3. 首次使用需要在專案根目錄放置 `credentials.json`(從 Google Cloud Console 下載的 OAuth2 憑證)
603
- 4. 首次運行時會自動開啟瀏覽器進行授權,授權後會生成 `token.json` 文件
604
- 5. 郵件內容由 AI 自動生成,請在發送前檢查結果
605
- 6. 寄件者固定為:{EMAIL_SENDER}
606
-
607
- **設置步驟:**
608
- - 前往 [Google Cloud Console](https://console.cloud.google.com/) 創建專案
609
- - 啟用 Gmail API
610
- - 創建 OAuth2 憑證並下載為 `credentials.json`
611
- - 將 `credentials.json` 放在專案根目錄
612
- """
613
- )
614
 
615
 
616
- def _create_calendar_interface():
617
- """創建 Calendar Tool 界面"""
618
- from ..agents.calendar_agent import generate_calendar_draft, create_calendar_draft
619
- from datetime import datetime, timedelta
620
-
621
- gr.Markdown(
622
- """
623
- ### 📅 智能行事曆管理助手
624
-
625
- 使用 AI 根據您的完整提示自動生成行事曆事件草稿,您可以在創建前檢查和修改。
626
-
627
- **使用方式:**
628
- 1. **快速選擇**:點擊下方常見事件按鈕,自動生成草稿
629
- 2. **自定義輸入**:在下方輸入完整的事件提示,包含:事件、日期、時間、地點、參與者
630
- 3. 查看 AI 反思評估結果和改進建議(如有)
631
- 4. 如果有缺失的資訊(如時間),系統會顯示下拉選單讓您選擇
632
- 5. 檢查並修改生成的事件內容
633
- 6. 確認無誤後點擊「創建事件」按鈕
634
-
635
- **✨ 新功能:AI 迭代反思評估 + Google Maps 地點驗證**
636
- - 系統會自動進行多輪反思評估(最多 3 輪)
637
- - 自動驗證並標準化地址,計算交通時間
638
- - 每輪評估後,如果有改進建議,會自動生成改進版本
639
- - 改進後的版本會再次評估,直到 AI 認為滿意為止
640
- """
641
- )
642
-
643
- # 快速選擇按鈕區域
644
- gr.Markdown("### 🚀 快速選擇常見事件")
645
- with gr.Row():
646
- quick_meeting_btn = gr.Button("📋 團隊會議", variant="secondary", scale=1)
647
- quick_client_btn = gr.Button("🤝 客戶拜訪", variant="secondary", scale=1)
648
- quick_lunch_btn = gr.Button("🍽️ 午餐會議", variant="secondary", scale=1)
649
- quick_oneonone_btn = gr.Button("💬 一對一會議", variant="secondary", scale=1)
650
- with gr.Row():
651
- quick_project_btn = gr.Button("📊 項目討論", variant="secondary", scale=1)
652
- quick_training_btn = gr.Button("🎓 培訓/學習", variant="secondary", scale=1)
653
- quick_social_btn = gr.Button("🎉 社交活動", variant="secondary", scale=1)
654
- quick_custom_btn = gr.Button("✏️ 自定義輸入", variant="secondary", scale=1)
655
-
656
- with gr.Row():
657
- with gr.Column(scale=1):
658
- # 單一 prompt 輸入
659
- calendar_prompt_input = gr.Textbox(
660
- label="📝 事件提示(包含事件、日期、時間、地點、參與者)",
661
- placeholder="例如:明天下午2點團隊會議,討論項目進度,地點在會議室A,參與者包括john@example.com和mary@example.com",
662
- lines=5,
663
- value=""
664
- )
665
-
666
- # 按鈕
667
- with gr.Row():
668
- generate_draft_btn = gr.Button("📝 生成事件草稿", variant="primary", scale=1)
669
- clear_calendar_btn = gr.Button("🗑️ 清除", variant="secondary", scale=1)
670
-
671
- # 狀態顯示
672
- calendar_status_display = gr.Textbox(
673
- label="📊 狀態",
674
- value="等待操作...",
675
- interactive=False,
676
- lines=2
677
- )
678
-
679
- # 反思結果顯示
680
- calendar_reflection_display = gr.Textbox(
681
- label="🔍 AI 反思評估",
682
- value="等待生成事件...",
683
- interactive=False,
684
- lines=8,
685
- visible=True
686
- )
687
-
688
- # 缺失資訊的補充區域(動態顯示)
689
- missing_info_group = gr.Group(visible=False)
690
- with missing_info_group:
691
- gr.Markdown("**⚠️ 請補充以下缺失的資訊:**")
692
-
693
- # 日期選擇(如果缺失)
694
- missing_date_display = gr.Dropdown(
695
- label="📆 選擇日期",
696
- choices=[],
697
- visible=False,
698
- interactive=True
699
- )
700
-
701
- # 時間選擇(如果缺失)
702
- missing_time_display = gr.Dropdown(
703
- label="🕐 選擇時間",
704
- choices=[],
705
- visible=False,
706
- interactive=True
707
- )
708
-
709
- fill_missing_btn = gr.Button("✅ 確認補充資訊", variant="primary", visible=False)
710
-
711
- # 隱藏狀態變數,用於存儲 event_dict
712
- event_dict_storage = gr.State(value={})
713
-
714
- with gr.Column(scale=1):
715
- # 事件詳情顯示和編輯區域
716
- event_summary_display = gr.Textbox(
717
- label="📌 事件標題",
718
- placeholder="事件標題將在這裡顯示",
719
- lines=1,
720
- interactive=True
721
- )
722
-
723
- event_start_display = gr.Textbox(
724
- label="🕐 開始時間",
725
- placeholder="開始時間將在這裡顯示(格式: YYYY-MM-DDTHH:MM:SS+08:00)",
726
- lines=1,
727
- interactive=True
728
- )
729
-
730
- event_end_display = gr.Textbox(
731
- label="🕐 結束時間",
732
- placeholder="結束時間將在這裡顯示(格式: YYYY-MM-DDTHH:MM:SS+08:00)",
733
- lines=1,
734
- interactive=True
735
- )
736
-
737
- event_description_display = gr.Textbox(
738
- label="📄 事件描述(可編輯)",
739
- placeholder="事件描述將在這裡顯示,您可以編輯",
740
- lines=6,
741
- interactive=True
742
- )
743
-
744
- event_location_display = gr.Textbox(
745
- label="📍 地點(可編輯,已自動驗證並標準化)",
746
- placeholder="事件地點將在這裡顯示,您可以編輯",
747
- lines=2,
748
- interactive=True
749
- )
750
-
751
- event_attendees_display = gr.Textbox(
752
- label="👥 參與者郵箱(可編輯,多個用逗號分隔)",
753
- placeholder="參與者郵箱將在這裡顯示,您可以編輯",
754
- lines=1,
755
- interactive=True
756
- )
757
-
758
- # 創建按鈕
759
- create_event_btn = gr.Button("✅ 創建事件", variant="primary", scale=1)
760
-
761
- # 操作結果顯示
762
- calendar_result_display = gr.Textbox(
763
- label="📊 操作結果",
764
- lines=8,
765
- interactive=False
766
- )
767
-
768
- # 生成時間選項(每30分鐘一個選項)
769
- def generate_time_options():
770
- """生成時間選項列表"""
771
- times = []
772
- for hour in range(24):
773
- for minute in [0, 30]:
774
- time_str = f"{hour:02d}:{minute:02d}"
775
- times.append(time_str)
776
- return times
777
-
778
- # 生成日期選項(今天、明天、後天,以及未來7天)
779
- def generate_date_options():
780
- """生成日期選項列表"""
781
- dates = []
782
- today = datetime.now()
783
- date_names = ["今天", "明天", "後天"]
784
-
785
- for i in range(3):
786
- date_obj = today + timedelta(days=i)
787
- date_str = date_obj.strftime('%Y-%m-%d')
788
- dates.append(f"{date_names[i]} ({date_str})")
789
-
790
- for i in range(3, 7):
791
- date_obj = today + timedelta(days=i)
792
- date_str = date_obj.strftime('%Y-%m-%d')
793
- dates.append(date_str)
794
-
795
- return dates
796
-
797
- # 快速選擇事件模板生成函數
798
- def generate_quick_prompt(event_type: str) -> str:
799
- """根據事件類型生成預設提示"""
800
- from datetime import datetime, timedelta
801
-
802
- # 獲取明天的日期
803
- tomorrow = datetime.now() + timedelta(days=1)
804
- tomorrow_str = tomorrow.strftime("%Y-%m-%d")
805
-
806
- templates = {
807
- "meeting": f"明天下午2點團隊會議,討論項目進度和下週計劃,地點在會議室,參與者包括團隊成員",
808
- "client": f"明天上午10點客戶拜訪,討論合作方案和需求,地點在客戶公司或會議室",
809
- "lunch": f"明天中午12點午餐會議,與合作夥伴討論業務合作,地點在附近的餐廳",
810
- "oneonone": f"明天下午3點一對一會議,討論工作進展和職業發展,地點在會議室或咖啡廳",
811
- "project": f"明天上午9點項目討論會議,審查項目進度和解決問題,地點在項目室,參與者包括項目團隊",
812
- "training": f"明天下午2點培訓課程,學習新技能和最佳實踐,地點在培訓室或線上",
813
- "social": f"明天晚上6點團隊聚餐,慶祝項目完成,地點在餐廳,參與者包括團隊成員",
814
- "custom": "" # 自定義,返回空讓用戶輸入
815
- }
816
-
817
- return templates.get(event_type, "")
818
-
819
- # 快速選擇按鈕處理函數(自動生成草稿)
820
- def quick_select_and_generate(event_type: str):
821
- """快速選擇事件類型並自動生成草稿"""
822
- prompt = generate_quick_prompt(event_type)
823
- if not prompt:
824
- # 如果是自定義,只返回空提示,不自動生成
825
- return (
826
- prompt, # calendar_prompt_input
827
- "請在下方輸入框中輸入事件提示,然後點擊「生成事件草稿」", # calendar_status_display
828
- "等待輸入...", # calendar_reflection_display
829
- gr.update(visible=False), # missing_info_group
830
- gr.update(visible=False, choices=[]), # missing_date_display
831
- gr.update(visible=False, choices=[]), # missing_time_display
832
- gr.update(visible=False), # fill_missing_btn
833
- "", "", "", "", "", "", # event fields
834
- {}, # event_dict_storage
835
- "" # calendar_result_display
836
- )
837
-
838
- # 自動生成草稿(調用 generate_draft 並返回所有輸出)
839
- draft_result = generate_draft(prompt)
840
- # generate_draft 返回的格式是:(status, reflection_display, missing_info_group, ...)
841
- # 但我們需要返回 (prompt, status, reflection_display, ...)
842
- # draft_result 是一個元組,我們需要將 prompt 添加到開頭
843
- return (prompt,) + draft_result
844
-
845
- def quick_select_meeting():
846
- """快速選擇:團隊會議"""
847
- return quick_select_and_generate("meeting")
848
-
849
- def quick_select_client():
850
- """快速選擇:客戶拜訪"""
851
- return quick_select_and_generate("client")
852
-
853
- def quick_select_lunch():
854
- """快速選擇:午餐會議"""
855
- return quick_select_and_generate("lunch")
856
-
857
- def quick_select_oneonone():
858
- """快速選擇:一對一會議"""
859
- return quick_select_and_generate("oneonone")
860
-
861
- def quick_select_project():
862
- """快速選擇:項目討論"""
863
- return quick_select_and_generate("project")
864
-
865
- def quick_select_training():
866
- """快速選擇:培訓/學習"""
867
- return quick_select_and_generate("training")
868
-
869
- def quick_select_social():
870
- """快速選擇:社交活動"""
871
- return quick_select_and_generate("social")
872
-
873
- def quick_select_custom():
874
- """快速選擇:自定義輸入(只清空,不自動生成)"""
875
- return (
876
- "", # calendar_prompt_input
877
- "請在下方輸入框中輸入事件提示,然後點擊「生成事件草稿」", # calendar_status_display
878
- "等待輸入...", # calendar_reflection_display
879
- gr.update(visible=False), # missing_info_group
880
- gr.update(visible=False, choices=[]), # missing_date_display
881
- gr.update(visible=False, choices=[]), # missing_time_display
882
- gr.update(visible=False), # fill_missing_btn
883
- "", "", "", "", "", "", # event fields
884
- {}, # event_dict_storage
885
- "" # calendar_result_display
886
- )
887
-
888
- # 事件處理函數
889
- def generate_draft(prompt):
890
- """生成行事曆事件草稿(包含反思功能)"""
891
- if not prompt or not prompt.strip():
892
- return (
893
- "❌ 請輸入事件提示",
894
- "❌ 請輸入事件提示",
895
- gr.update(visible=False),
896
- gr.update(visible=False, choices=[]),
897
- gr.update(visible=False, choices=[]),
898
- gr.update(visible=False),
899
- "", "", "", "", "", "", {},
900
- "❌ 請輸入事件提示"
901
- )
902
-
903
- try:
904
- status_msg = "🔄 正在生成事件草稿..."
905
-
906
- # 生成事件草稿(包含反思功能)
907
- event_dict, status, missing_info, reflection_result, was_improved = generate_calendar_draft(
908
- prompt.strip(), enable_reflection=True
909
- )
910
-
911
- if not event_dict:
912
- return (
913
- status,
914
- gr.update(visible=False),
915
- gr.update(visible=False, choices=[]),
916
- gr.update(visible=False, choices=[]),
917
- gr.update(visible=False),
918
- "", "", "", "", "", "", "",
919
- status
920
- )
921
-
922
- # 格式化反思結果顯示
923
- if reflection_result:
924
- # 計算反思輪數
925
- reflection_count = reflection_result.count("【第") if "【第" in reflection_result else 0
926
-
927
- if was_improved:
928
- if reflection_count > 1:
929
- reflection_display = (
930
- f"🔍 **AI 迭代反思評估結果**(共 {reflection_count} 輪)\n\n"
931
- f"{reflection_result}\n\n"
932
- f"✨ **已自動應用改進建議,經過 {reflection_count} 輪優化,當前顯示的是最終優化版本**"
933
- )
934
- else:
935
- reflection_display = (
936
- f"🔍 **AI 反思評估結果**\n\n"
937
- f"{reflection_result}\n\n"
938
- f"✨ **已自動應用改進建議,當前顯示的是優化後的版本**"
939
- )
940
- else:
941
- reflection_display = (
942
- f"🔍 **AI 反思評估結果**\n\n"
943
- f"{reflection_result}\n\n"
944
- f"✅ **事件質量良好,無需改進**"
945
- )
946
- else:
947
- reflection_display = "⚠️ 反思功能未返回結果"
948
-
949
- # 【Google Maps 整合】添加地點建議訊息
950
- location_suggestion = event_dict.get("location_suggestion", "")
951
- if location_suggestion:
952
- # 將地點建議添加到狀態訊息中
953
- if status:
954
- status = f"{status}\n\n🗺️ **地點資訊:**\n{location_suggestion}"
955
- else:
956
- status = f"🗺️ **地點資訊:**\n{location_suggestion}"
957
-
958
- # 檢查是否有缺失資訊
959
- has_missing = bool(missing_info)
960
-
961
- if has_missing:
962
- # 顯示缺失資訊區域
963
- date_visible = missing_info.get("date", False)
964
- time_visible = missing_info.get("time", False)
965
-
966
- date_choices = generate_date_options() if date_visible else []
967
- time_choices = generate_time_options() if time_visible else []
968
-
969
- return (
970
- status,
971
- reflection_display,
972
- gr.update(visible=True), # 顯示缺失資訊區域
973
- gr.update(visible=date_visible, choices=date_choices, value=date_choices[0] if date_choices else None),
974
- gr.update(visible=time_visible, choices=time_choices, value=time_choices[0] if time_choices else None),
975
- gr.update(visible=True), # 顯示確認按鈕
976
- event_dict.get("summary", ""),
977
- event_dict.get("start_datetime", ""),
978
- event_dict.get("end_datetime", ""),
979
- event_dict.get("description", ""),
980
- event_dict.get("location", ""),
981
- event_dict.get("attendees", ""),
982
- event_dict, # 傳遞完整的事件字典以便後續使用
983
- ""
984
- )
985
- else:
986
- # 沒有缺失資訊,直接顯示結果
987
- return (
988
- status,
989
- reflection_display,
990
- gr.update(visible=False),
991
- gr.update(visible=False, choices=[]),
992
- gr.update(visible=False, choices=[]),
993
- gr.update(visible=False),
994
- event_dict.get("summary", ""),
995
- event_dict.get("start_datetime", ""),
996
- event_dict.get("end_datetime", ""),
997
- event_dict.get("description", ""),
998
- event_dict.get("location", ""),
999
- event_dict.get("attendees", ""),
1000
- event_dict,
1001
- ""
1002
- )
1003
- except Exception as e:
1004
- error_msg = f"❌ 發生錯誤:{str(e)}"
1005
- print(f"Calendar Tool 錯誤:{e}")
1006
- import traceback
1007
- traceback.print_exc()
1008
- return (
1009
- "❌ 發生錯誤",
1010
- f"❌ 發生錯誤:{str(e)}",
1011
- gr.update(visible=False),
1012
- gr.update(visible=False, choices=[]),
1013
- gr.update(visible=False, choices=[]),
1014
- gr.update(visible=False),
1015
- "", "", "", "", "", "", {},
1016
- error_msg
1017
- )
1018
-
1019
- def fill_missing_info(event_dict_storage, selected_date, selected_time):
1020
- """填充缺失的資訊"""
1021
- if not event_dict_storage:
1022
- return (
1023
- "❌ 沒有事件資料",
1024
- gr.update(visible=False),
1025
- gr.update(visible=False, choices=[]),
1026
- gr.update(visible=False, choices=[]),
1027
- gr.update(visible=False),
1028
- "", "", "", "", "", "",
1029
- {}
1030
- )
1031
-
1032
- # 更新日期和時間
1033
- if selected_date:
1034
- # 從選項中提取日期字串(例如:"明天 (2026-01-25)" -> "2026-01-25")
1035
- if "(" in selected_date:
1036
- date_str = selected_date.split("(")[1].split(")")[0]
1037
- else:
1038
- date_str = selected_date
1039
- else:
1040
- date_str = event_dict_storage.get("date", "今天")
1041
-
1042
- if selected_time:
1043
- time_str = selected_time
1044
- else:
1045
- time_str = "09:00" # 預設時間
1046
-
1047
- # 重新解析日期和時間
1048
- from ..agents.calendar_agent import parse_datetime
1049
- start_datetime, end_datetime = parse_datetime(date_str, time_str)
1050
-
1051
- # 更新事件字典
1052
- event_dict_storage["start_datetime"] = start_datetime
1053
- event_dict_storage["end_datetime"] = end_datetime
1054
-
1055
- return (
1056
- "✅ 資訊已補充,請檢查並創建事件",
1057
- gr.update(visible=False), # 隱藏缺失資訊區域
1058
- gr.update(visible=False, choices=[]),
1059
- gr.update(visible=False, choices=[]),
1060
- gr.update(visible=False),
1061
- event_dict_storage.get("summary", ""),
1062
- start_datetime,
1063
- end_datetime,
1064
- event_dict_storage.get("description", ""),
1065
- event_dict_storage.get("location", ""),
1066
- event_dict_storage.get("attendees", ""),
1067
- event_dict_storage
1068
- )
1069
-
1070
- def create_event(summary, start_datetime, end_datetime, description, location, attendees):
1071
- """創建行事曆事件"""
1072
- if not summary or not summary.strip():
1073
- return "❌ 請輸入事件標題", "❌ 請輸入事件標題"
1074
-
1075
- if not start_datetime or not start_datetime.strip():
1076
- return "❌ 請輸入開始時間", "❌ 請輸入開始時間"
1077
-
1078
- if not end_datetime or not end_datetime.strip():
1079
- return "❌ 請輸入結束時間", "❌ 請輸入結束時間"
1080
-
1081
- try:
1082
- status_msg = "🔄 正在創建事件..."
1083
-
1084
- # 構建事件字典
1085
- event_dict = {
1086
- "summary": summary.strip(),
1087
- "start_datetime": start_datetime.strip(),
1088
- "end_datetime": end_datetime.strip(),
1089
- "description": description.strip() if description else "",
1090
- "location": location.strip() if location else "",
1091
- "attendees": attendees.strip() if attendees else "",
1092
- "timezone": "Asia/Taipei"
1093
- }
1094
-
1095
- # 創建事件
1096
- result = create_calendar_draft(event_dict)
1097
-
1098
- return "✅ 事件已創建", result
1099
- except Exception as e:
1100
- error_msg = f"❌ 創建事件時發生錯誤:{str(e)}"
1101
- print(f"Calendar Tool 錯誤:{e}")
1102
- import traceback
1103
- traceback.print_exc()
1104
- return "❌ 發生錯誤", error_msg
1105
-
1106
- def clear_calendar():
1107
- """清除行事曆相關輸入和輸出"""
1108
- return (
1109
- "", # prompt
1110
- "等待操作...", # status
1111
- "等待生成事件...", # reflection_display
1112
- gr.update(visible=False), # missing_info_group
1113
- gr.update(visible=False, choices=[]), # missing_date
1114
- gr.update(visible=False, choices=[]), # missing_time
1115
- gr.update(visible=False), # fill_missing_btn
1116
- "", "", "", "", "", "", # event fields
1117
- {}, # event_dict_storage
1118
- "" # result
1119
- )
1120
-
1121
- # 綁定事件
1122
- generate_draft_btn.click(
1123
- fn=generate_draft,
1124
- inputs=[calendar_prompt_input],
1125
- outputs=[
1126
- calendar_status_display,
1127
- calendar_reflection_display,
1128
- missing_info_group,
1129
- missing_date_display,
1130
- missing_time_display,
1131
- fill_missing_btn,
1132
- event_summary_display,
1133
- event_start_display,
1134
- event_end_display,
1135
- event_description_display,
1136
- event_location_display,
1137
- event_attendees_display,
1138
- event_dict_storage,
1139
- calendar_result_display
1140
- ]
1141
- )
1142
-
1143
- # 綁定快速選擇按鈕(自動填充提示並生成草稿)
1144
- quick_outputs = [
1145
- calendar_prompt_input, # 更新提示輸入框
1146
- calendar_status_display,
1147
- calendar_reflection_display,
1148
- missing_info_group,
1149
- missing_date_display,
1150
- missing_time_display,
1151
- fill_missing_btn,
1152
- event_summary_display,
1153
- event_start_display,
1154
- event_end_display,
1155
- event_description_display,
1156
- event_location_display,
1157
- event_attendees_display,
1158
- event_dict_storage,
1159
- calendar_result_display
1160
- ]
1161
-
1162
- quick_meeting_btn.click(fn=quick_select_meeting, outputs=quick_outputs)
1163
- quick_client_btn.click(fn=quick_select_client, outputs=quick_outputs)
1164
- quick_lunch_btn.click(fn=quick_select_lunch, outputs=quick_outputs)
1165
- quick_oneonone_btn.click(fn=quick_select_oneonone, outputs=quick_outputs)
1166
- quick_project_btn.click(fn=quick_select_project, outputs=quick_outputs)
1167
- quick_training_btn.click(fn=quick_select_training, outputs=quick_outputs)
1168
- quick_social_btn.click(fn=quick_select_social, outputs=quick_outputs)
1169
- quick_custom_btn.click(fn=quick_select_custom, outputs=quick_outputs)
1170
-
1171
- fill_missing_btn.click(
1172
- fn=fill_missing_info,
1173
- inputs=[event_dict_storage, missing_date_display, missing_time_display],
1174
- outputs=[
1175
- calendar_status_display,
1176
- missing_info_group,
1177
- missing_date_display,
1178
- missing_time_display,
1179
- fill_missing_btn,
1180
- event_summary_display,
1181
- event_start_display,
1182
- event_end_display,
1183
- event_description_display,
1184
- event_location_display,
1185
- event_attendees_display,
1186
- event_dict_storage
1187
- ]
1188
- )
1189
-
1190
- create_event_btn.click(
1191
- fn=create_event,
1192
- inputs=[
1193
- event_summary_display,
1194
- event_start_display,
1195
- event_end_display,
1196
- event_description_display,
1197
- event_location_display,
1198
- event_attendees_display
1199
- ],
1200
- outputs=[calendar_status_display, calendar_result_display]
1201
- )
1202
-
1203
- clear_calendar_btn.click(
1204
- fn=clear_calendar,
1205
- outputs=[
1206
- calendar_prompt_input,
1207
- calendar_status_display,
1208
- calendar_reflection_display,
1209
- missing_info_group,
1210
- missing_date_display,
1211
- missing_time_display,
1212
- fill_missing_btn,
1213
- event_summary_display,
1214
- event_start_display,
1215
- event_end_display,
1216
- event_description_display,
1217
- event_location_display,
1218
- event_attendees_display,
1219
- event_dict_storage,
1220
- calendar_result_display
1221
- ]
1222
- )
1223
-
1224
- # 示例
1225
- gr.Examples(
1226
- examples=[
1227
- "明天下午2點團隊會議,討論項目進度,地點在會議室A,參與者包括john@example.com",
1228
- "2026-01-25 上午9點產品發布會,介紹新功能和改進,地點在總部大樓",
1229
- "後天下午3點客戶會議,討論合作細節,參與者包括客戶代表",
1230
- "下週一上午10點技術分享會,分享最新的 AI 技術,地點在研發中心"
1231
- ],
1232
- inputs=[calendar_prompt_input]
1233
- )
1234
-
1235
- # 頁腳說明
1236
- gr.Markdown(
1237
- """
1238
- ---
1239
- **注意事項:**
1240
- 1. 使用 Google Calendar API 管理行事曆事件
1241
- 2. 首次使用需要在專案根目錄放置 `credentials.json`(從 Google Cloud Console 下載的 OAuth2 憑證)
1242
- 3. 首次運行時會自動開啟瀏覽器進行授權,授權後會生成 `token.json` 文件
1243
- 4. 事件內容由 AI 自動生成,請在創建前檢查結果
1244
- 5. 在提示中包含所有資訊:事件、日期、時間、地點、參與者
1245
- 6. 如果缺少日期或時間,系統會顯示下拉選單讓您選擇
1246
- 7. 日期格式支援:YYYY-MM-DD(例如:2026-01-25)或相對日期(今天��明天、後天)
1247
- 8. 時間格式支援:24小時制(14:00)或12小時制(2:00 PM)
1248
-
1249
- **設置步驟:**
1250
- - 前往 [Google Cloud Console](https://console.cloud.google.com/) 創建專案
1251
- - 啟用 Google Calendar API
1252
- - 創建 OAuth2 憑證並下載為 `credentials.json`
1253
- - 將 `credentials.json` 放在專案根目錄
1254
- - 確保授予 Calendar API 的完整存取權限
1255
- """
1256
- )
1257
 
 
 
 
5
  import uuid
6
  import re
7
  import time
8
+ import json
9
+ import os
10
  from typing import Iterator, Tuple
11
  import gradio as gr
12
  from langchain_core.messages import HumanMessage
13
 
14
  # graph 和 rag_retriever 將從外部傳入,不在這裡導入
15
  from ..utils.llm_utils import get_llm_type, is_using_local_llm
16
+ from .email_interface import _create_email_interface
17
+ from .calendar_interface import _create_calendar_interface
18
+ from .private_file_rag_interface import _create_private_file_rag_interface
19
 
20
 
21
  def run_research_agent(query: str, graph, thread_id: str = None) -> Iterator[Tuple[str, str, str, str, str]]:
 
209
  gr.Markdown(
210
  """
211
  <div class="header">
212
+ <h1>🚀 Deep Research Agent with RAG</h1>
213
  <p><strong>功能特色:</strong></p>
214
+ <p>📊 股票資訊查詢 | 🌐 網路搜尋 | 📚 PDF 知識庫查詢(Tree of Thoughts 論文)| 📧 智能郵件助手 | 📅 智能行事曆管理 | 📄 私有文件 RAG 問答</p>
215
  <p><strong>智能規劃:</strong> 系統會根據問題類型自動選擇合適的研究工具</p>
216
  <p><strong>本地模型:</strong> 使用 MLX 本地模型,保護隱私,無需 API 金鑰</p>
217
  </div>
 
232
  # Tab 3: Calendar Tool
233
  with gr.Tab("📅 Calendar Tool"):
234
  _create_calendar_interface()
235
+
236
+ # Tab 4: Private File RAG
237
+ with gr.Tab("📚 Private File RAG"):
238
+ _create_private_file_rag_interface()
239
 
240
  return demo
241
 
 
372
  )
373
 
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
 
378
+
379
+
deep_agent_rag/ui/private_file_rag_interface.py ADDED
@@ -0,0 +1,663 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # deep_agent_rag/ui/private_file_rag_interface.py
2
+
3
+ import gradio as gr
4
+ import re
5
+ import json
6
+ import os
7
+ import time
8
+
9
+ from ..rag.private_file_rag import get_private_rag_instance, reset_private_rag_instance
10
+ # Assuming is_using_local_llm might be used for warnings/status, similar to email_interface
11
+ # from ..utils.llm_utils import is_using_local_llm
12
+
13
+ # Agent log path for debugging (if needed)
14
+ log_path = "/Users/matthuang/Desktop/Deep_Agentic_AI_Tool/.cursor/debug.log"
15
+
16
+ def _create_private_file_rag_interface():
17
+ """創建私有文件 RAG 界面(對話式 Chatbot)"""
18
+ gr.Markdown(
19
+ """
20
+ ### 📚 私有文件 RAG 對話系統
21
+
22
+ 上傳您的私有文件(PDF、DOCX、TXT),系統會自動建立 RAG 知識庫,讓 AI 可以回答關於這些文件的問題。
23
+ 支持多輪對話,AI 會記住之前的對話內容,提供更連貫的回答。
24
+
25
+ **使用方式:**
26
+ 1. 上傳一個或多個文件(PDF、DOCX、TXT)
27
+ 2. 點擊「處理文件」按鈕,系統會自動處理文件並建立 RAG 系統
28
+ 3. 在對話框中輸入您的問題,按 Enter 或點擊「發送」按鈕
29
+ 4. AI 會基於上傳的文件回答問題,支持多輪對話
30
+
31
+ **功能特色:**
32
+ - 💬 **對話式界面** :類似 Gemini 的對話體驗,支持多輪對話
33
+ - 📄 支持多種文件格式:PDF、DOCX、TXT
34
+ - 🔍 使用混合搜尋(BM25 + 向量檢索)提升檢索準確度
35
+ - 🎯 可選重排序功能,進一步優化結果
36
+ - 🧠 支持語義分塊,保持語義完整性
37
+ - 🌐 自動檢測文檔類型並調整回答風格
38
+
39
+ **LLM 使用策略:**
40
+ - 🥇 **優先使用 Groq API** :如果配置了 API 金鑰,優先使用 Groq(速度快、質量高)
41
+ - 🥈 **其次使用 Ollama** :如果 Groq 不可用,自動切換到 Ollama 本地模型
42
+ - 🥉 **最後使用 MLX** :如果前兩者都不可用,使用 MLX 本地模型作為備選
43
+ - 💡 **自動切換** :系統會根據 API 額度、服務狀態等自動選擇最合適的 LLM
44
+
45
+ **注意:** 此功能需要 Learn_RAG 項目在正確的位置
46
+ """
47
+ )
48
+
49
+ # 對話歷史狀態
50
+ chat_history = gr.State(value=[])
51
+
52
+ with gr.Row():
53
+ # 左側:文件上傳和設置
54
+ with gr.Column(scale=1):
55
+ # 文件上傳區域
56
+ file_upload = gr.File(
57
+ label="📁 上傳文件(PDF、DOCX、TXT)",
58
+ file_count="multiple",
59
+ file_types=[ ".pdf", ".docx", ".doc", ".txt"]
60
+ )
61
+
62
+ # 處理按鈕
63
+ with gr.Row():
64
+ process_btn = gr.Button("📝 處理文件", variant="primary", scale=1)
65
+ clear_files_btn = gr.Button("🗑️ 清除所有", variant="secondary", scale=1)
66
+
67
+ # 處理狀態
68
+ process_status = gr.Textbox(
69
+ label="📊 處理狀態",
70
+ value="等待上傳文件...",
71
+ interactive=False,
72
+ lines=2
73
+ )
74
+
75
+ # 設置區域(使用 Accordion 摺疊)
76
+ with gr.Accordion("⚙️ 進階設置", open=False):
77
+ # 處理選項
78
+ use_semantic_chunking = gr.Checkbox(
79
+ label="使用語義分塊(推薦)",
80
+ value=False,
81
+ info="語義分塊能保持語義完整性,但處理時間較長"
82
+ )
83
+
84
+ # 分塊參數調整(字符分塊模式)
85
+ gr.Markdown("**📏 字符分塊參數(僅在未使用語義分塊時有效)**")
86
+ chunk_size_slider = gr.Slider(
87
+ minimum=200,
88
+ maximum=1500,
89
+ value=500,
90
+ step=50,
91
+ label="分塊大小(字符數)",
92
+ info="建議:300-800"
93
+ )
94
+ chunk_overlap_slider = gr.Slider(
95
+ minimum=0,
96
+ maximum=300,
97
+ value=100,
98
+ step=25,
99
+ label="分塊重疊(字符數)",
100
+ info="建議:chunk_size 的 15-25%"
101
+ )
102
+
103
+ # 語義分塊參數調整(僅在使用語義分塊時有效)
104
+ gr.Markdown("**🔬 語義分塊參數(僅在使用語義分塊時有效)**")
105
+ semantic_threshold_slider = gr.Slider(
106
+ minimum=0.5,
107
+ maximum=2.5,
108
+ value=1.0,
109
+ step=0.1,
110
+ label="語義分塊閾值(敏感度)",
111
+ info="建議:0.8-1.2(細粒度)"
112
+ )
113
+ semantic_min_chunk_slider = gr.Slider(
114
+ minimum=50,
115
+ maximum=300,
116
+ value=100,
117
+ step=25,
118
+ label="最小分塊大小(字符數)",
119
+ info="建議:50-200"
120
+ )
121
+
122
+ # RAG 方法選擇
123
+ gr.Markdown("**🎯 RAG 方法選擇**")
124
+ enable_adaptive_selection = gr.Checkbox(
125
+ label="自動選擇最佳 RAG 方法(推薦)",
126
+ value=True,
127
+ info="系統會根據查詢和文件特征自動選擇最合適的 RAG 方法"
128
+ )
129
+ manual_rag_method = gr.Dropdown(
130
+ choices=[
131
+ "basic",
132
+ "subquery",
133
+ "hyde",
134
+ "step_back",
135
+ "hybrid_subquery_hyde",
136
+ "triple_hybrid"
137
+ ],
138
+ value="basic",
139
+ label="手動選擇 RAG 方法",
140
+ info="僅在自動選擇關閉時生效",
141
+ visible=False
142
+ )
143
+
144
+ # 查詢選項
145
+ top_k_slider = gr.Slider(
146
+ minimum=1,
147
+ maximum=10,
148
+ value=3,
149
+ step=1,
150
+ label="返回結果數量"
151
+ )
152
+ use_llm_checkbox = gr.Checkbox(
153
+ label="使用 LLM 生成回答",
154
+ value=True
155
+ )
156
+
157
+ # 右側:對話界面
158
+ with gr.Column(scale=2):
159
+ # Chatbot 組件
160
+ # #region agent log
161
+ try:
162
+ with open(log_path, "a", encoding="utf-8") as f:
163
+ log_entry = {
164
+ "sessionId": "debug-session",
165
+ "runId": "run1",
166
+ "hypothesisId": "A",
167
+ "location": "private_file_rag_interface.py:1409", # Adjusted line number
168
+ "message": "Before Chatbot creation",
169
+ "data": {
170
+ "gradio_version": gr.__version__ if hasattr(gr, '__version__') else "unknown"
171
+ },
172
+ "timestamp": int(time.time() * 1000)
173
+ }
174
+ f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
175
+ except:
176
+ pass
177
+ # #endregion
178
+
179
+ # 創建 Chatbot(移除不支持的參數:show_copy_button 和 avatar_images)
180
+ # #region agent log
181
+ try:
182
+ with open(log_path, "a", encoding="utf-8") as f:
183
+ log_entry = {
184
+ "sessionId": "debug-session",
185
+ "runId": "run1",
186
+ "hypothesisId": "A",
187
+ "location": "private_file_rag_interface.py:1430", # Adjusted line number
188
+ "message": "Creating Chatbot with minimal params",
189
+ "data": {"params": ["label", "height"]},
190
+ "timestamp": int(time.time() * 1000)
191
+ }
192
+ f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
193
+ except:
194
+ pass
195
+ # #endregion
196
+
197
+ try:
198
+ chatbot = gr.Chatbot(
199
+ label="💬 對話",
200
+ height=500
201
+ )
202
+ # #region agent log
203
+ try:
204
+ with open(log_path, "a", encoding="utf-8") as f:
205
+ log_entry = {
206
+ "sessionId": "debug-session",
207
+ "runId": "run1",
208
+ "hypothesisId": "A",
209
+ "location": "private_file_rag_interface.py:1448", # Adjusted line number
210
+ "message": "Chatbot created successfully",
211
+ "data": {"success": True},
212
+ "timestamp": int(time.time() * 1000)
213
+ }
214
+ f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
215
+ except:
216
+ pass
217
+ # #endregion
218
+ except Exception as e:
219
+ # #region agent log
220
+ try:
221
+ with open(log_path, "a", encoding="utf-8") as f:
222
+ log_entry = {
223
+ "sessionId": "debug-session",
224
+ "runId": "run1",
225
+ "hypothesisId": "A",
226
+ "location": "private_file_rag_interface.py:1460", # Adjusted line number
227
+ "message": "Chatbot creation failed",
228
+ "data": {
229
+ "error_type": type(e).__name__,
230
+ "error_message": str(e)
231
+ },
232
+ "timestamp": int(time.time() * 1000)
233
+ }
234
+ f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
235
+ except:
236
+ pass
237
+ # #endregion
238
+ raise
239
+
240
+ # 輸入框
241
+ msg = gr.Textbox(
242
+ label="輸入問題",
243
+ placeholder="輸入您的問題,按 Enter 發送...",
244
+ lines=2,
245
+ scale=4
246
+ )
247
+
248
+ # 按鈕區域
249
+ with gr.Row():
250
+ submit_btn = gr.Button("📤 發送", variant="primary", scale=1)
251
+ clear_chat_btn = gr.Button("🗑️ 清除對話", variant="secondary", scale=1)
252
+
253
+ # 查詢狀態
254
+ query_status = gr.Textbox(
255
+ label="📊 狀態",
256
+ value="等待查詢...",
257
+ interactive=False,
258
+ lines=1
259
+ )
260
+
261
+ # 輔助函數:轉換 Gradio 歷史格式(dict)和 RAG 歷史格式(tuple)
262
+ def history_dict_to_tuple(history_dict):
263
+ """
264
+ 將 Gradio 歷史格式(List[Dict])轉換為 RAG 歷史格式(List[Tuple[str, str]])
265
+
266
+ Args:
267
+ history_dict: Gradio 格式的歷史,每個元素為 {"role": "user"/"assistant", "content": "..."}
268
+
269
+ Returns:
270
+ RAG 格式的歷史,每個元素為 (user_message, assistant_message)
271
+ """
272
+ if not history_dict:
273
+ return []
274
+
275
+ conversation_history = []
276
+ current_user_msg = None
277
+
278
+ for msg in history_dict:
279
+ if isinstance(msg, dict):
280
+ role = msg.get("role", "")
281
+ content = msg.get("content", "")
282
+
283
+ if role == "user":
284
+ current_user_msg = content
285
+ elif role == "assistant" and current_user_msg is not None:
286
+ conversation_history.append((current_user_msg, content))
287
+ current_user_msg = None
288
+ elif isinstance(msg, tuple) and len(msg) == 2:
289
+ # 如果已經是 tuple 格式,直接使用(向後兼容)
290
+ conversation_history.append(msg)
291
+
292
+ return conversation_history
293
+
294
+ def history_tuple_to_dict(history_tuple):
295
+ """
296
+ 將 RAG 歷史格式(List[Tuple[str, str]])轉換為 Gradio 歷史格式(List[Dict])
297
+
298
+ Args:
299
+ history_tuple: RAG 格式的歷史,每個元素為 (user_message, assistant_message)
300
+
301
+ Returns:
302
+ Gradio 格式的歷史,每個元素為 {"role": "user"/"assistant", "content": "..."}
303
+ """
304
+ if not history_tuple:
305
+ return []
306
+
307
+ history_dict = []
308
+ for msg in history_tuple:
309
+ if isinstance(msg, dict):
310
+ # 如果已經是 dict 格式,直接使用
311
+ history_dict.append(msg)
312
+ elif isinstance(msg, tuple) and len(msg) == 2:
313
+ # 轉換 tuple 為 dict 格式
314
+ user_msg, assistant_msg = msg
315
+ history_dict.append({"role": "user", "content": user_msg})
316
+ history_dict.append({"role": "assistant", "content": assistant_msg})
317
+
318
+ return history_dict
319
+
320
+ def ensure_dict_format(history):
321
+ """
322
+ 確保歷史是 Gradio dict 格式
323
+
324
+ Args:
325
+ history: 歷史列表(可能是 dict 或 tuple 格式,也可能是 None)
326
+
327
+ Returns:
328
+ Gradio 格式的歷史(List[Dict])
329
+ """
330
+ if not history:
331
+ return []
332
+
333
+ # 檢查第一個元素的類型來判斷格式
334
+ try:
335
+ if isinstance(history[0], dict):
336
+ return history
337
+ elif isinstance(history[0], tuple):
338
+ return history_tuple_to_dict(history)
339
+ else:
340
+ # 未知格式,返回空列表
341
+ return []
342
+ except (IndexError, TypeError):
343
+ # 如果 history 為空或無法索引,返回空列表
344
+ return []
345
+
346
+ # 事件處理函數
347
+ def process_files(files, use_semantic, chunk_size, chunk_overlap, semantic_threshold, semantic_min_chunk):
348
+ """
349
+ 處理上傳的文件
350
+
351
+ Args:
352
+ files: 上傳的文件列表
353
+ use_semantic: 是否使用語義分塊
354
+ chunk_size: 字符分塊大小(僅用於字符分塊模式)
355
+ chunk_overlap: 字符分塊重疊大小(僅用於字符分塊模式)
356
+ semantic_threshold: 語義分塊閾值(僅用於語義分塊模式)
357
+ semantic_min_chunk: 語義分塊最小 chunk 大小(僅用於語義分塊模式)
358
+ """
359
+ if not files:
360
+ return "❌ 請先上傳文件", "等待上傳文件..."
361
+
362
+ try:
363
+ # 獲取 RAG ��例
364
+ rag = get_private_rag_instance()
365
+
366
+ # 更新配置
367
+ rag.use_semantic_chunking = use_semantic
368
+
369
+ # 更新分塊參數(根據分塊模式選擇)
370
+ if not use_semantic:
371
+ # 字符分塊模式:更新字符分塊參數
372
+ rag.chunk_size = int(chunk_size)
373
+ rag.chunk_overlap = int(chunk_overlap)
374
+ print(f"📏 使用字符分塊:chunk_size={rag.chunk_size}, chunk_overlap={rag.chunk_overlap}")
375
+ else:
376
+ # 語義分塊模式:更新語義分塊參數
377
+ rag.semantic_threshold = float(semantic_threshold)
378
+ rag.semantic_min_chunk_size = int(semantic_min_chunk)
379
+ print(f"📏 使用語義分塊:threshold={rag.semantic_threshold}, min_chunk_size={rag.semantic_min_chunk_size}")
380
+
381
+ # 處理上傳的文件(Gradio 會自動保存到臨時目錄)
382
+ # Gradio 6.x 返回的是文件路徑字符串列表
383
+ file_paths = []
384
+
385
+ for file in files:
386
+ # Gradio 6.x 返回字符串路徑,舊版本可能返回文件對象
387
+ if isinstance(file, str):
388
+ file_path = file
389
+ elif hasattr(file, 'name'):
390
+ # 舊版本 Gradio 文件對象
391
+ file_path = file.name
392
+ else:
393
+ # 嘗試轉換為字符串
394
+ file_path = str(file)
395
+
396
+ if os.path.exists(file_path):
397
+ file_paths.append(file_path)
398
+ else:
399
+ return f"❌ 文件不存在: {file_path}", "處理失敗"
400
+
401
+ if not file_paths:
402
+ return "❌ 沒有有效的文件路徑", "處理失敗"
403
+
404
+ # 處理文件
405
+ documents, status_msg = rag.process_files(file_paths)
406
+
407
+ if documents:
408
+ return status_msg, "✅ 文件處理完成,可以開始查詢"
409
+ else:
410
+ return status_msg, "❌ 處理失敗"
411
+
412
+ except Exception as e:
413
+ error_msg = f"❌ 處理文件時發生錯誤: {str(e)}"
414
+ print(error_msg)
415
+ import traceback
416
+ traceback.print_exc()
417
+ return error_msg, "❌ 處理失敗"
418
+
419
+ def query_rag_stream(message, history, top_k, use_llm, enable_adaptive, manual_method):
420
+ """
421
+ 查詢 RAG 系統(對話式,流式輸出)
422
+
423
+ Args:
424
+ message: 當前用戶消息
425
+ history: 對話歷史(Gradio 格式:List[Dict] 或 List[Tuple[str, str]])
426
+ top_k: 返回結果數量
427
+ use_llm: 是否使用 LLM 生成回答
428
+ enable_adaptive: 是否啟用自動選擇
429
+ manual_method: 手動選擇的方法(僅在自動選擇關閉時生效)
430
+
431
+ Yields:
432
+ Tuple[history, status_msg]: 逐步更新的對話歷史和狀態訊息
433
+ """
434
+ if not message or not message.strip():
435
+ yield history, "❌ 請輸入問題"
436
+ return
437
+
438
+ try:
439
+ # 獲取 RAG 實例
440
+ rag = get_private_rag_instance()
441
+
442
+ if not rag.is_initialized:
443
+ error_msg = "❌ RAG 系統尚未初始化,請先處理文件"
444
+ # 確保 history 是 dict 格式
445
+ history = ensure_dict_format(history)
446
+ history.append({"role": "user", "content": message})
447
+ history.append({"role": "assistant", "content": error_msg})
448
+ yield history, error_msg
449
+ return
450
+
451
+ # 設置 RAG 方法選擇參數
452
+ rag.enable_adaptive_selection = enable_adaptive
453
+ if not enable_adaptive:
454
+ rag.selected_rag_method = manual_method
455
+ else:
456
+ rag.selected_rag_method = None
457
+
458
+ # 準備對話歷史:轉換為 RAG 需要的 tuple 格式
459
+ conversation_history = history_dict_to_tuple(history) if history else []
460
+
461
+ # 確保 history 是 dict 格式並添加用戶消息
462
+ history = ensure_dict_format(history)
463
+ history.append({"role": "user", "content": message})
464
+
465
+ # 執行查詢(傳入對話歷史,使用流式輸出)
466
+ if use_llm:
467
+ # 使用流式查詢
468
+ answer_generator = rag.query_stream(
469
+ query=message,
470
+ top_k=int(top_k),
471
+ conversation_history=conversation_history
472
+ )
473
+
474
+ # 初始化回答
475
+ accumulated_answer = ""
476
+ history_with_user = history.copy()
477
+ final_result = {}
478
+
479
+ # 逐步接收流式回答
480
+ for chunk in answer_generator:
481
+ if chunk.get("success") is False:
482
+ error = chunk.get("error", "未知錯誤")
483
+ error_msg = f"❌ 查詢失敗: {error}"
484
+ history_with_user.append({"role": "assistant", "content": error_msg})
485
+ yield history_with_user, error_msg
486
+ return
487
+
488
+ # 保存最後一個 chunk 作為最終結果
489
+ final_result = chunk
490
+
491
+ # 獲取新的回答片段
492
+ new_answer = chunk.get("answer", "")
493
+ if new_answer:
494
+ # 累積回答
495
+ accumulated_answer = new_answer
496
+ # 更新歷史
497
+ history_with_answer = history_with_user.copy()
498
+ history_with_answer.append({"role": "assistant", "content": accumulated_answer})
499
+ yield history_with_answer, "🔄 正在生成回答..."
500
+
501
+ # 獲取最終結果(包含統計信息)
502
+ rag_method = final_result.get("rag_method", "basic")
503
+ stats = final_result.get("stats", {})
504
+ status_msg = f"✅ 查詢完成(方法: {rag_method.upper()})"
505
+ if stats:
506
+ total_time = stats.get("total_time", 0)
507
+ if total_time > 0:
508
+ status_msg += f" | 耗時: {total_time:.2f}秒"
509
+
510
+ # 確保最終回答完整
511
+ if accumulated_answer:
512
+ history_with_answer = history_with_user.copy()
513
+ history_with_answer.append({"role": "assistant", "content": accumulated_answer})
514
+ yield history_with_answer, status_msg
515
+ else:
516
+ error_msg = "⚠️ LLM 未生成回答(可能 LLM 服務未啟動)"
517
+ history_with_answer = history_with_user.copy()
518
+ history_with_answer.append({"role": "assistant", "content": error_msg})
519
+ yield history_with_answer, status_msg
520
+ else:
521
+ # 不使用 LLM,直接返回檢索結果
522
+ result = rag.query(
523
+ query=message,
524
+ top_k=int(top_k),
525
+ use_llm=False,
526
+ conversation_history=conversation_history
527
+ )
528
+
529
+ if not result.get("success"):
530
+ error = result.get("error", "未知錯誤")
531
+ error_msg = f"❌ 查詢失敗: {error}"
532
+ history.append({"role": "assistant", "content": error_msg})
533
+ yield history, error_msg
534
+ return
535
+
536
+ # 格式化檢索結果
537
+ formatted_context = result.get("formatted_context", "")
538
+ answer = f"📄 檢索到的相關內容:\n\n{formatted_context}"
539
+
540
+ # 獲取 RAG 方法信息
541
+ rag_method = result.get("rag_method", "basic")
542
+ stats = result.get("stats", {})
543
+ status_msg = f"✅ 查詢完成(方法: {rag_method.upper()})"
544
+ if stats:
545
+ total_time = stats.get("total_time", 0)
546
+ if total_time > 0:
547
+ status_msg += f" | 耗時: {total_time:.2f}秒"
548
+
549
+ history.append({"role": "assistant", "content": answer})
550
+ yield history, status_msg
551
+
552
+ except Exception as e:
553
+ error_msg = f"❌ 查詢時發生錯誤: {str(e)}"
554
+ print(error_msg)
555
+ import traceback
556
+ traceback.print_exc()
557
+ # 確保 history 是 dict 格式
558
+ history = ensure_dict_format(history)
559
+ if not any(msg.get("role") == "user" and msg.get("content") == message for msg in history):
560
+ history.append({"role": "user", "content": message})
561
+ history.append({"role": "assistant", "content": error_msg})
562
+ yield history, error_msg
563
+
564
+ def clear_chat():
565
+ """清除對話歷史(不重置 RAG 系統)"""
566
+ return [], "對話已清除"
567
+
568
+ def clear_all():
569
+ """清除所有內容(包括 RAG 系統)"""
570
+ reset_private_rag_instance()
571
+ empty_history = []
572
+ return (
573
+ None, # file_upload
574
+ False, # use_semantic_chunking
575
+ 500, # chunk_size_slider
576
+ 100, # chunk_overlap_slider
577
+ 1.0, # semantic_threshold_slider
578
+ 100, # semantic_min_chunk_slider
579
+ True, # enable_adaptive_selection
580
+ "basic", # manual_rag_method
581
+ "等待上傳文件...", # process_status
582
+ empty_history, # chatbot (對話歷史)
583
+ empty_history, # chat_history (狀態)
584
+ "等��查詢...", # query_status
585
+ )
586
+
587
+ # 綁定事件
588
+ process_btn.click(
589
+ fn=process_files,
590
+ inputs=[
591
+ file_upload,
592
+ use_semantic_chunking,
593
+ chunk_size_slider,
594
+ chunk_overlap_slider,
595
+ semantic_threshold_slider,
596
+ semantic_min_chunk_slider
597
+ ],
598
+ outputs=[process_status, query_status]
599
+ )
600
+
601
+ # 自動選擇開關時顯示/隱藏手動選擇下拉菜單
602
+ def toggle_manual_method(enable_adaptive):
603
+ return gr.update(visible=not enable_adaptive)
604
+
605
+ enable_adaptive_selection.change(
606
+ fn=toggle_manual_method,
607
+ inputs=[enable_adaptive_selection],
608
+ outputs=[manual_rag_method]
609
+ )
610
+
611
+ # 提交消息(按鈕點擊或 Enter 鍵)
612
+ def submit_message(message, history, top_k, use_llm, enable_adaptive, manual_method):
613
+ """提交消息並更新對話歷史(流式輸出)"""
614
+ if not message or not message.strip():
615
+ # 確保 history 是 dict 格式
616
+ history = ensure_dict_format(history)
617
+ return history, history, "", "等待查詢..."
618
+ # 清空輸入框並執行流式查詢
619
+ for new_history, status in query_rag_stream(message, history, top_k, use_llm, enable_adaptive, manual_method):
620
+ yield new_history, new_history, "", status
621
+
622
+ # 綁定提交按鈕和 Enter 鍵
623
+ submit_btn.click(
624
+ fn=submit_message,
625
+ inputs=[msg, chat_history, top_k_slider, use_llm_checkbox, enable_adaptive_selection, manual_rag_method],
626
+ outputs=[chatbot, chat_history, msg, query_status]
627
+ )
628
+
629
+ msg.submit(
630
+ fn=submit_message,
631
+ inputs=[msg, chat_history, top_k_slider, use_llm_checkbox, enable_adaptive_selection, manual_rag_method],
632
+ outputs=[chatbot, chat_history, msg, query_status]
633
+ )
634
+
635
+ # 清除對話按鈕(需要更新 chat_history 狀態)
636
+ def clear_chat_with_state():
637
+ """清除對話歷史並更新狀態"""
638
+ empty_history = []
639
+ return empty_history, empty_history, "對話已清除"
640
+
641
+ clear_chat_btn.click(
642
+ fn=clear_chat_with_state,
643
+ outputs=[chatbot, chat_history, query_status]
644
+ )
645
+
646
+ # 清除所有按鈕
647
+ clear_files_btn.click(
648
+ fn=clear_all,
649
+ outputs=[
650
+ file_upload,
651
+ use_semantic_chunking,
652
+ chunk_size_slider,
653
+ chunk_overlap_slider,
654
+ semantic_threshold_slider,
655
+ semantic_min_chunk_slider,
656
+ enable_adaptive_selection,
657
+ manual_rag_method,
658
+ process_status,
659
+ chatbot, # 更新 chatbot 顯示
660
+ chat_history, # 更新 chat_history 狀態
661
+ query_status
662
+ ]
663
+ )
deep_agent_rag/utils/llm_utils.py CHANGED
@@ -1,11 +1,12 @@
1
  """
2
  LLM 工具函數
3
  提供 LLM 實例的創建和管理
4
- 優先使用 Groq API,額度用完後自動切換到本地 MLX 模型
5
  """
6
  import warnings
7
  from typing import Optional
8
  from langchain_groq import ChatGroq
 
9
  from ..models import MLXChatModel, load_mlx_model
10
  from ..config import (
11
  MLX_MAX_TOKENS,
@@ -14,7 +15,12 @@ from ..config import (
14
  GROQ_MODEL,
15
  GROQ_MAX_TOKENS,
16
  GROQ_TEMPERATURE,
17
- USE_GROQ_FIRST
 
 
 
 
 
18
  )
19
 
20
  # 全局變量:跟踪當前使用的 LLM 類型
@@ -29,30 +35,17 @@ def get_llm_type() -> str:
29
 
30
  def is_using_local_llm() -> bool:
31
  """檢查是否正在使用本地 LLM"""
32
- return _current_llm_type == "mlx" or _groq_quota_exceeded
33
 
34
 
35
  def get_llm():
36
  """
37
  獲取 LLM 實例
38
- 優先使用 Groq API,額度用完後自動切換到本地 MLX 模型
39
  """
40
  global _current_llm_type, _groq_quota_exceeded
41
 
42
- # 如果已經知道 Groq 額度用完,直接使用本地模型
43
- if _groq_quota_exceeded:
44
- if _current_llm_type != "mlx":
45
- print("⚠️ 警告:Groq API 額度已用完,已切換到本地 MLX 模型 (Qwen2.5)")
46
- _current_llm_type = "mlx"
47
- model, tokenizer = load_mlx_model()
48
- return MLXChatModel(
49
- model=model,
50
- tokenizer=tokenizer,
51
- max_tokens=MLX_MAX_TOKENS,
52
- temperature=MLX_TEMPERATURE
53
- )
54
-
55
- # 嘗試使用 Groq API
56
  if USE_GROQ_FIRST and GROQ_API_KEY:
57
  try:
58
  groq_llm = ChatGroq(
@@ -61,48 +54,61 @@ def get_llm():
61
  max_tokens=GROQ_MAX_TOKENS,
62
  temperature=GROQ_TEMPERATURE
63
  )
64
- # 測試連接(通過一個簡單的調用來驗證)
65
- # 注意:這裡不實際調用,只是創建實例
66
  _current_llm_type = "groq"
67
  print("✅ 使用 Groq API (優先)")
68
  return groq_llm
69
  except Exception as e:
70
- # 如果創建失敗,可能是 API key 無效
71
  print(f"⚠️ Groq API 初始化失敗: {e}")
72
- _groq_quota_exceeded = True
73
- _current_llm_type = "mlx"
74
- print("⚠️ 警告:已切換到本地 MLX 模型 (Qwen2.5)")
75
- model, tokenizer = load_mlx_model()
76
- return MLXChatModel(
77
- model=model,
78
- tokenizer=tokenizer,
79
- max_tokens=MLX_MAX_TOKENS,
80
- temperature=MLX_TEMPERATURE
 
81
  )
82
- else:
83
- # 如果沒有配置 Groq 或選擇不使用,直接使用本地模型
84
- if not GROQ_API_KEY:
85
- print("ℹ️ 未配置 GROQ_API_KEY,使用本地 MLX 模型")
86
- _current_llm_type = "mlx"
87
- model, tokenizer = load_mlx_model()
88
- return MLXChatModel(
89
- model=model,
90
- tokenizer=tokenizer,
91
- max_tokens=MLX_MAX_TOKENS,
92
- temperature=MLX_TEMPERATURE
93
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
 
96
  def handle_groq_error(error: Exception) -> Optional[MLXChatModel]:
97
  """
98
  處理 Groq API 錯誤
99
- 如果是額度用完錯誤,切換到本地模型
100
 
101
  Args:
102
  error: 捕獲的異常
103
 
104
  Returns:
105
- 如果切換到本地模型,返回 MLXChatModel;否則返回 None
106
  """
107
  global _current_llm_type, _groq_quota_exceeded
108
 
@@ -121,10 +127,27 @@ def handle_groq_error(error: Exception) -> Optional[MLXChatModel]:
121
  if any(indicator in error_str for indicator in quota_indicators):
122
  if not _groq_quota_exceeded:
123
  _groq_quota_exceeded = True
124
- warning_msg = "⚠️ 警告:Groq API 額度已用完,正在切換到本地 MLX 模型 (Qwen2.5)"
125
  print(warning_msg)
126
  warnings.warn(warning_msg, UserWarning)
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  _current_llm_type = "mlx"
129
  model, tokenizer = load_mlx_model()
130
  return MLXChatModel(
 
1
  """
2
  LLM 工具函數
3
  提供 LLM 實例的創建和管理
4
+ 優先順序:Groq API > Ollama > MLX 模型
5
  """
6
  import warnings
7
  from typing import Optional
8
  from langchain_groq import ChatGroq
9
+ from langchain_ollama import ChatOllama
10
  from ..models import MLXChatModel, load_mlx_model
11
  from ..config import (
12
  MLX_MAX_TOKENS,
 
15
  GROQ_MODEL,
16
  GROQ_MAX_TOKENS,
17
  GROQ_TEMPERATURE,
18
+ USE_GROQ_FIRST,
19
+ OLLAMA_BASE_URL,
20
+ OLLAMA_MODEL,
21
+ OLLAMA_MAX_TOKENS,
22
+ OLLAMA_TEMPERATURE,
23
+ USE_OLLAMA,
24
  )
25
 
26
  # 全局變量:跟踪當前使用的 LLM 類型
 
35
 
36
  def is_using_local_llm() -> bool:
37
  """檢查是否正在使用本地 LLM"""
38
+ return _current_llm_type in ["mlx", "ollama"] or _groq_quota_exceeded
39
 
40
 
41
  def get_llm():
42
  """
43
  獲取 LLM 實例
44
+ 優先順序:Groq API > Ollama > MLX 模型
45
  """
46
  global _current_llm_type, _groq_quota_exceeded
47
 
48
+ # 優先順序 1: Groq API
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  if USE_GROQ_FIRST and GROQ_API_KEY:
50
  try:
51
  groq_llm = ChatGroq(
 
54
  max_tokens=GROQ_MAX_TOKENS,
55
  temperature=GROQ_TEMPERATURE
56
  )
 
 
57
  _current_llm_type = "groq"
58
  print("✅ 使用 Groq API (優先)")
59
  return groq_llm
60
  except Exception as e:
61
+ # 如果創建失敗,繼續嘗試其他選項
62
  print(f"⚠️ Groq API 初始化失敗: {e}")
63
+ # 不立即設置 _groq_quota_exceeded,先嘗試 Ollama
64
+
65
+ # 優先順序 2: Ollama (Llama 3.2 或其他模型)
66
+ if USE_OLLAMA:
67
+ try:
68
+ ollama_llm = ChatOllama(
69
+ base_url=OLLAMA_BASE_URL,
70
+ model=OLLAMA_MODEL,
71
+ num_predict=OLLAMA_MAX_TOKENS,
72
+ temperature=OLLAMA_TEMPERATURE,
73
  )
74
+ _current_llm_type = "ollama"
75
+ print(f"✅ 使用 Ollama 模型 ({OLLAMA_MODEL})")
76
+ return ollama_llm
77
+ except Exception as e:
78
+ print(f"⚠️ Ollama 初始化失敗: {e}")
79
+ print(" 請確保 Ollama 服務正在運行: ollama serve")
80
+ print(" 或檢查模型是否已下載: ollama pull " + OLLAMA_MODEL)
81
+
82
+ # 優先順序 3: MLX 模型(備援)
83
+ # 如果 Groq 額度用完,記錄狀態
84
+ if _groq_quota_exceeded and _current_llm_type != "mlx":
85
+ print("⚠️ 警告:Groq API 額度已用完,已切換到本地 MLX 模型 (Qwen2.5)")
86
+ elif _current_llm_type != "mlx":
87
+ if not GROQ_API_KEY and not USE_OLLAMA:
88
+ print("ℹ️ 未配置 Groq API 或 Ollama,使用本地 MLX 模型")
89
+ elif not USE_OLLAMA:
90
+ print("ℹ️ Ollama 未啟用,使用本地 MLX 模型作為備援")
91
+
92
+ _current_llm_type = "mlx"
93
+ model, tokenizer = load_mlx_model()
94
+ return MLXChatModel(
95
+ model=model,
96
+ tokenizer=tokenizer,
97
+ max_tokens=MLX_MAX_TOKENS,
98
+ temperature=MLX_TEMPERATURE
99
+ )
100
 
101
 
102
  def handle_groq_error(error: Exception) -> Optional[MLXChatModel]:
103
  """
104
  處理 Groq API 錯誤
105
+ 如果是額度用完錯誤,先嘗試切換到 Ollama,否則切換到 MLX 模型
106
 
107
  Args:
108
  error: 捕獲的異常
109
 
110
  Returns:
111
+ 如果切換到本地模型,返回 ChatOllama 或 MLXChatModel;否則返回 None
112
  """
113
  global _current_llm_type, _groq_quota_exceeded
114
 
 
127
  if any(indicator in error_str for indicator in quota_indicators):
128
  if not _groq_quota_exceeded:
129
  _groq_quota_exceeded = True
130
+ warning_msg = "⚠️ 警告:Groq API 額度已用完"
131
  print(warning_msg)
132
  warnings.warn(warning_msg, UserWarning)
133
 
134
+ # 先嘗試使用 Ollama
135
+ if USE_OLLAMA:
136
+ try:
137
+ ollama_llm = ChatOllama(
138
+ base_url=OLLAMA_BASE_URL,
139
+ model=OLLAMA_MODEL,
140
+ num_predict=OLLAMA_MAX_TOKENS,
141
+ temperature=OLLAMA_TEMPERATURE,
142
+ )
143
+ _current_llm_type = "ollama"
144
+ print(f"✅ 已切換到 Ollama 模型 ({OLLAMA_MODEL})")
145
+ return ollama_llm
146
+ except Exception as e:
147
+ print(f"⚠️ Ollama 切換失敗: {e}")
148
+ print(" 回退到 MLX 模型")
149
+
150
+ # 回退到 MLX 模型
151
  _current_llm_type = "mlx"
152
  model, tokenizer = load_mlx_model()
153
  return MLXChatModel(
pyproject.toml CHANGED
@@ -17,6 +17,7 @@ dependencies = [
17
  "yfinance>=0.2.66",
18
  "langgraph>=1.0.4",
19
  "langchain-groq>=1.1.0",
 
20
  "grandalf>=0.8",
21
  "langserve[all]>=0.3.3",
22
  "fastapi>=0.124.2",
@@ -36,5 +37,15 @@ dependencies = [
36
  "google-auth-httplib2>=0.3.0",
37
  "google-auth-oauthlib>=1.2.3",
38
  "googlemaps>=4.10.0",
 
39
  "parlant @ git+https://github.com/emcie-co/parlant@develop",
 
 
 
 
 
 
 
 
 
40
  ]
 
17
  "yfinance>=0.2.66",
18
  "langgraph>=1.0.4",
19
  "langchain-groq>=1.1.0",
20
+ "langchain-ollama>=0.1.0",
21
  "grandalf>=0.8",
22
  "langserve[all]>=0.3.3",
23
  "fastapi>=0.124.2",
 
37
  "google-auth-httplib2>=0.3.0",
38
  "google-auth-oauthlib>=1.2.3",
39
  "googlemaps>=4.10.0",
40
+ <<<<<<< HEAD
41
  "parlant @ git+https://github.com/emcie-co/parlant@develop",
42
+ =======
43
+ # Learn_RAG 項目依賴(用於私有文件 RAG 功能)
44
+ "arxiv>=2.3.1",
45
+ "langchain-text-splitters>=0.0.1",
46
+ "rank-bm25>=0.2.2",
47
+ "chromadb>=0.4.22",
48
+ "docx2txt>=0.8",
49
+ "langchain-experimental>=0.0.50",
50
+ >>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
51
  ]
src/__init__.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG 系統模組套件
3
+ """
4
+ from .document_processor import DocumentProcessor
5
+ from .retrievers import (
6
+ BaseRetriever,
7
+ BM25Retriever,
8
+ VectorRetriever,
9
+ HybridSearch,
10
+ Reranker,
11
+ RAGPipeline,
12
+ )
13
+ from .prompt_formatter import PromptFormatter
14
+ from .llm_integration import OllamaLLM
15
+ from .subquery_rag import SubQueryDecompositionRAG
16
+ from .hyde_rag import HyDERAG
17
+ from .hybrid_subquery_hyde_rag import HybridSubqueryHyDERAG
18
+ from .step_back_rag import StepBackRAG
19
+ from .triple_hybrid_rag import TripleHybridRAG
20
+
21
+ __all__ = [
22
+ "DocumentProcessor",
23
+ "BaseRetriever",
24
+ "BM25Retriever",
25
+ "VectorRetriever",
26
+ "HybridSearch",
27
+ "Reranker",
28
+ "RAGPipeline",
29
+ "PromptFormatter",
30
+ "OllamaLLM",
31
+ "SubQueryDecompositionRAG",
32
+ "HyDERAG",
33
+ "HybridSubqueryHyDERAG",
34
+ "StepBackRAG",
35
+ "TripleHybridRAG",
36
+ ]
37
+
src/document_processor.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 文檔處理模組:載入 arXiv 論文並進行文字分割
3
+ 支援本地檔案:PDF, DOCX, TXT
4
+ 支援兩種分塊策略:
5
+ 1. 字符分塊(預設):基於固定字符數的分塊,速度快
6
+ 2. 語義分塊(可選):基於語義相似度的分塊,能保持語義完整性
7
+ """
8
+ from typing import List, Dict, Optional, Any
9
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
10
+ from pathlib import Path
11
+ import os
12
+ import arxiv
13
+ import re
14
+
15
+ # 嘗試導入語義分塊器(需要 langchain-experimental)
16
+ try:
17
+ from langchain_experimental.text_splitter import SemanticChunker
18
+ SEMANTIC_CHUNKER_AVAILABLE = True
19
+ except ImportError:
20
+ SEMANTIC_CHUNKER_AVAILABLE = False
21
+
22
+
23
+ class DocumentProcessor:
24
+ """
25
+ 處理 arXiv 論文文檔,進行分割和準備
26
+
27
+ 支援兩種分塊模式:
28
+ - 字符分塊(預設):快速、穩定,適合大多數場景
29
+ - 語義分塊(可選):更智能,能保持語義完整性,但需要額外依賴和計算時間
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ chunk_size: int = 1000,
35
+ chunk_overlap: int = 200,
36
+ embeddings: Optional[Any] = None, # 可選:用於語義分塊的 embedding 模型
37
+ use_semantic_chunking: bool = False, # 是否使用語義分塊
38
+ breakpoint_threshold_amount: float = 1.5, # 語義分塊敏感度(標準差倍數)
39
+ min_chunk_size: int = 100 # 語義分塊的最小 chunk 大小(字符數)
40
+ ):
41
+ """
42
+ 初始化文檔處理器
43
+
44
+ Args:
45
+ chunk_size: 每個 chunk 的大小(字符數),僅用於字符分塊模式
46
+ chunk_overlap: chunk 之間的重疊大小(字符數),僅用於字符分塊模式
47
+ embeddings: 用於計算語義距離的 embedding 模型物件(可選)
48
+ 當 use_semantic_chunking=True 時必須提供
49
+ use_semantic_chunking: 是否使用語義分塊
50
+ True: 使用語義分塊(需要提供 embeddings)
51
+ False: 使用字符分塊(預設)
52
+ breakpoint_threshold_amount: 語義分塊的敏感度參數
53
+ 數值越大,分塊越少(chunks 越大)
54
+ 數值越小,分塊越多(chunks 越小)
55
+ 建議範圍:1.0 - 2.0,預設 1.5
56
+ min_chunk_size: 語義分塊的最小 chunk 大小(字符數)
57
+ 小於此大小的 chunks 會被合併到相鄰的 chunks
58
+ 預設 100 字符
59
+ """
60
+ self.embeddings = embeddings
61
+ self.use_semantic_chunking = use_semantic_chunking
62
+ self.min_chunk_size = min_chunk_size
63
+
64
+ # 如果要求使用語義分塊
65
+ if use_semantic_chunking:
66
+ # 檢查是否安裝了必要的依賴
67
+ if not SEMANTIC_CHUNKER_AVAILABLE:
68
+ raise ImportError(
69
+ "使用語義分塊需要安裝 langchain-experimental 套件。\n"
70
+ "請執行: pip install langchain-experimental\n"
71
+ "或使用字符分塊模式(use_semantic_chunking=False)"
72
+ )
73
+
74
+ # 檢查是否提供了 embeddings
75
+ if embeddings is None:
76
+ raise ValueError(
77
+ "使用語義分塊時必須提供 embeddings 參數。\n"
78
+ "範例:\n"
79
+ " from langchain_community.embeddings import HuggingFaceEmbeddings\n"
80
+ " embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')\n"
81
+ " processor = DocumentProcessor(embeddings=embeddings, use_semantic_chunking=True)"
82
+ )
83
+
84
+ # 初始化語義分塊器
85
+ # 使用「標準差」策略:當相鄰句子之間的語義距離超過平均距離的標準差倍數時,進行切分
86
+ self.text_splitter = SemanticChunker(
87
+ embeddings,
88
+ breakpoint_threshold_type="standard_deviation",
89
+ breakpoint_threshold_amount=breakpoint_threshold_amount
90
+ )
91
+ print(f"✓ 使用語義分塊模式(敏感度: {breakpoint_threshold_amount},最小 chunk 大小: {min_chunk_size} 字符)")
92
+ else:
93
+ # 使用傳統的字符分塊(預設模式)
94
+ self.text_splitter = RecursiveCharacterTextSplitter(
95
+ chunk_size=chunk_size,
96
+ chunk_overlap=chunk_overlap,
97
+ length_function=len,
98
+ )
99
+ print(f"✓ 使用字符分塊模式(大小: {chunk_size} 字符,重疊: {chunk_overlap} 字符)")
100
+
101
+ def _post_process_chunks(self, chunks: List[str]) -> List[str]:
102
+ """
103
+ 後處理 chunks:過濾和合併太小的 chunks
104
+
105
+ 語義分塊可能會產生一些非��小的 chunks(例如只有幾個單詞),
106
+ 這些小 chunks 可能不包含足夠的上下文資訊。此方法會:
107
+ 1. 將小於 min_chunk_size 的 chunks 合併到相鄰的 chunks
108
+ 2. 確保最終的 chunks 都有足夠的大小
109
+
110
+ Args:
111
+ chunks: 原始 chunks 列表(從分塊器產生的)
112
+
113
+ Returns:
114
+ 處理後的 chunks 列表(過濾和合併後的)
115
+ """
116
+ # 如果使用字符分塊,不需要後處理(因為已經有固定大小)
117
+ if not self.use_semantic_chunking:
118
+ return chunks
119
+
120
+ # 如果沒有 chunks,直接返回
121
+ if not chunks:
122
+ return chunks
123
+
124
+ processed = []
125
+ current_small_chunk = "" # 累積的小 chunk
126
+
127
+ for chunk in chunks:
128
+ chunk_stripped = chunk.strip()
129
+ chunk_length = len(chunk_stripped)
130
+
131
+ # 如果當前 chunk 太小,嘗試與下一個合併
132
+ if chunk_length < self.min_chunk_size:
133
+ # 累積到臨時變數中
134
+ if current_small_chunk:
135
+ current_small_chunk += "\n\n" + chunk
136
+ else:
137
+ current_small_chunk = chunk
138
+ else:
139
+ # 當前 chunk 足夠大
140
+ # 如果有累積的小 chunk,先處理它
141
+ if current_small_chunk:
142
+ current_small_chunk_stripped = current_small_chunk.strip()
143
+ if len(current_small_chunk_stripped) >= self.min_chunk_size:
144
+ # 累積後足夠大,作為獨立 chunk
145
+ processed.append(current_small_chunk)
146
+ else:
147
+ # 累積後還是太小,合併到上一個 chunk(如果存在)
148
+ if processed:
149
+ processed[-1] += "\n\n" + current_small_chunk
150
+ else:
151
+ # 如果沒有上一個 chunk,還是要保留
152
+ processed.append(current_small_chunk)
153
+ current_small_chunk = ""
154
+
155
+ # 添加當前足夠大的 chunk
156
+ processed.append(chunk)
157
+
158
+ # 處理最後的累積小 chunk
159
+ if current_small_chunk:
160
+ current_small_chunk_stripped = current_small_chunk.strip()
161
+ if len(current_small_chunk_stripped) >= self.min_chunk_size:
162
+ # 足夠大,作為獨立 chunk
163
+ processed.append(current_small_chunk)
164
+ elif processed:
165
+ # 太小,合併到最後一個 chunk
166
+ processed[-1] += "\n\n" + current_small_chunk
167
+ else:
168
+ # 如果沒有其他 chunks,還是要保留
169
+ processed.append(current_small_chunk)
170
+
171
+ return processed
172
+
173
+ def fetch_papers(self, query: str, max_results: int = 10) -> List[Dict]:
174
+ """
175
+ 從 arXiv 獲取論文
176
+
177
+ Args:
178
+ query: 搜尋查詢(例如 "cat:cs.AI")
179
+ max_results: 最大結果數量
180
+
181
+ Returns:
182
+ 論文列表,每個論文包含標題、摘要等資訊
183
+ """
184
+ search = arxiv.Search(
185
+ query=query,
186
+ max_results=max_results,
187
+ sort_by=arxiv.SortCriterion.SubmittedDate
188
+ )
189
+
190
+ papers = []
191
+ for paper in search.results():
192
+ papers.append({
193
+ "title": paper.title,
194
+ "authors": [author.name for author in paper.authors],
195
+ "summary": paper.summary,
196
+ "published": str(paper.published),
197
+ "arxiv_id": paper.entry_id.split('/')[-1],
198
+ "arxiv_url": paper.entry_id,
199
+ "pdf_url": paper.pdf_url,
200
+ "categories": paper.categories,
201
+ })
202
+
203
+ return papers
204
+
205
+ def process_documents(self, papers: List[Dict]) -> List[Dict]:
206
+ """
207
+ 處理論文,將每篇論文分割成 chunks
208
+
209
+ Args:
210
+ papers: 論文列表
211
+
212
+ Returns:
213
+ 處理後的文檔 chunks,每個 chunk 包含內容和元數據
214
+ """
215
+ documents = []
216
+
217
+ for paper in papers:
218
+ # 組合論文的完整文字(標題 + 摘要)
219
+ # 保留換行符號 \n\n 作為語義斷點的結構參考
220
+ full_text = f"Title: {paper['title']}\n\nAbstract: {paper['summary']}"
221
+
222
+ # 分割文字(根據選擇的模式:字符分塊或語義分塊)
223
+ chunks = self.text_splitter.split_text(full_text)
224
+
225
+ # 後處理:過濾和合併太小的 chunks(僅語義分塊模式)
226
+ chunks = self._post_process_chunks(chunks)
227
+
228
+ # 為每個 chunk 創建文檔物件
229
+ for i, chunk in enumerate(chunks):
230
+ doc = {
231
+ "content": chunk,
232
+ "metadata": {
233
+ "title": paper['title'],
234
+ "arxiv_id": paper['arxiv_id'],
235
+ "arxiv_url": paper['arxiv_url'],
236
+ "pdf_url": paper['pdf_url'],
237
+ "authors": paper['authors'],
238
+ "published": paper['published'],
239
+ "categories": paper['categories'],
240
+ "chunk_index": i,
241
+ "total_chunks": len(chunks),
242
+ "chunking_method": "semantic" if self.use_semantic_chunking else "character"
243
+ }
244
+ }
245
+ documents.append(doc)
246
+
247
+ return documents
248
+
249
+ def get_texts_and_metadatas(self, documents: List[Dict]):
250
+ """
251
+ 從文檔列表中提取文字和元數據
252
+
253
+ Args:
254
+ documents: 文檔列表
255
+
256
+ Returns:
257
+ (texts, metadatas) 元組
258
+ """
259
+ texts = [doc["content"] for doc in documents]
260
+ metadatas = [doc["metadata"] for doc in documents]
261
+ return texts, metadatas
262
+
263
+ @staticmethod
264
+ def clean_extracted_text(text: str) -> str:
265
+ """
266
+ 清理從 PDF/DOCX 提取的文本,移除多餘的空格和修復字符換行問題
267
+
268
+ 某些 PDF 提取工具會在每個字符之間插入空格或換行,特別是中文文本。
269
+ 此方法會:
270
+ 1. 修復「每個字符一行」的問題(將單字符行合併)
271
+ 2. 移除中文字符之間的多餘空格
272
+ 3. 保留英文單詞之間的空格
273
+ 4. 保留標點符號周圍的適當空格
274
+ 5. 保留真正的段落分隔
275
+
276
+ Args:
277
+ text: 原始提取的文本
278
+
279
+ Returns:
280
+ 清理後的文本
281
+ """
282
+ if not text:
283
+ return text
284
+
285
+ # 步驟 0: 修復「每個字符一行」的問題
286
+ # 檢測模式:每行只有一個字符(可能是中文字符、標點、或單個字母/數字)
287
+ # 將這些單字符行合併成連續文本
288
+ lines = text.split('\n')
289
+ merged_lines = []
290
+ i = 0
291
+
292
+ def is_single_char_line(line: str) -> bool:
293
+ """
294
+ 判斷是否為單字符行
295
+ 考慮:去除空格後長度 <= 3(可能是單字符+標點,或單字符+空格)
296
+ """
297
+ stripped = line.strip()
298
+ if not stripped:
299
+ return False # 空行不算
300
+ # 如果去除空格後長度 <= 3,且主要是中文字符、標點或單個字母/數字
301
+ if len(stripped) <= 3:
302
+ # 檢查是否主要是單個字符(可能帶標點或空格)
303
+ # 移除所有空格後,如果長度 <= 2,認為是單字符行
304
+ no_space = stripped.replace(' ', '')
305
+ if len(no_space) <= 2:
306
+ return True
307
+ return False
308
+
309
+ while i < len(lines):
310
+ line = lines[i]
311
+ stripped_line = line.strip()
312
+
313
+ # 如果當前行是單字符行
314
+ if is_single_char_line(line):
315
+ # 收集連續的單字符行(包括空行,因為空行可能是分隔符)
316
+ merged_chars = []
317
+ j = i
318
+ consecutive_single_chars = 0
319
+
320
+ while j < len(lines):
321
+ current_line = lines[j]
322
+ current_stripped = current_line.strip()
323
+
324
+ if is_single_char_line(current_line):
325
+ # 是單字符行,收集字符(去除空格)
326
+ char = current_stripped.replace(' ', '')
327
+ if char:
328
+ merged_chars.append(char)
329
+ consecutive_single_chars += 1
330
+ j += 1
331
+ elif not current_stripped:
332
+ # 空行:如果前面有單字符,且後面可能還有單字符,跳過空行
333
+ # 檢查下一行是否也是單字符
334
+ if j + 1 < len(lines) and is_single_char_line(lines[j + 1]):
335
+ # 空行後面還有單字符,跳過空行繼續收集
336
+ j += 1
337
+ else:
338
+ # 空行後面沒有單字符了,停止收集
339
+ break
340
+ else:
341
+ # 遇到正常行,停止收集
342
+ break
343
+
344
+ # 如果收集到多個單字符,合併它們
345
+ if len(merged_chars) > 1:
346
+ merged_text = ''.join(merged_chars)
347
+ merged_lines.append(merged_text)
348
+ i = j
349
+ continue
350
+ elif len(merged_chars) == 1 and consecutive_single_chars > 1:
351
+ # 只有一個字符但有多行(可能是空格導致的),也合併
352
+ merged_text = ''.join(merged_chars)
353
+ merged_lines.append(merged_text)
354
+ i = j
355
+ continue
356
+ else:
357
+ # 只有一個單字符,且確實只有一行,保留原樣
358
+ if merged_chars:
359
+ merged_lines.append(merged_chars[0])
360
+ i = j
361
+ continue
362
+ else:
363
+ # 正常行,直接添加
364
+ if stripped_line: # 非空行
365
+ merged_lines.append(stripped_line)
366
+ i += 1
367
+
368
+ # 重新組合文本
369
+ text = '\n'.join(merged_lines)
370
+
371
+ # 步驟 0.5: 再次處理可能的殘留問題
372
+ # 如果還有單字符行(可能是第一次處理遺漏的),再次處理
373
+ lines = text.split('\n')
374
+ final_lines = []
375
+ i = 0
376
+ while i < len(lines):
377
+ line = lines[i].strip()
378
+ if is_single_char_line(line):
379
+ # 再次收集連續的單字符行
380
+ merged_chars = []
381
+ j = i
382
+ while j < len(lines) and is_single_char_line(lines[j]):
383
+ char = lines[j].strip().replace(' ', '')
384
+ if char:
385
+ merged_chars.append(char)
386
+ j += 1
387
+
388
+ if len(merged_chars) > 1:
389
+ final_lines.append(''.join(merged_chars))
390
+ i = j
391
+ else:
392
+ if merged_chars:
393
+ final_lines.append(merged_chars[0])
394
+ i = j
395
+ else:
396
+ if line:
397
+ final_lines.append(line)
398
+ i += 1
399
+
400
+ text = '\n'.join(final_lines)
401
+
402
+ # 1. 移除中文字符之間的空格
403
+ # 匹配模式:中文字符 + 空格 + 中文字符
404
+ chinese_char_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
405
+ text = re.sub(chinese_char_pattern, r'\1\2', text)
406
+
407
+ # 2. 移除中文和標點符號之間的多餘空格
408
+ # 中文 + 空格 + 標點符號
409
+ chinese_punct_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([,。、;:!?""''()【】《》])'
410
+ text = re.sub(chinese_punct_pattern, r'\1\2', text)
411
+
412
+ # 標點符號 + 空格 + 中文
413
+ # 使用 re.escape 來正確處理標點符號,避免轉義序列警告
414
+ punct_chars = ',。、;:!?""''()【】《》'
415
+ punct_chinese_pattern = f'([{re.escape(punct_chars)}])\\s+([\\u4e00-\\u9fff\\u3400-\\u4dbf\\uf900-\\ufaff])'
416
+ text = re.sub(punct_chinese_pattern, r'\1\2', text)
417
+
418
+ # 3. 移除數字和中文之間的多餘空格(例如:"500 公里" -> "500公里")
419
+ number_chinese_pattern = r'(\d+)\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
420
+ text = re.sub(number_chinese_pattern, r'\1\2', text)
421
+ chinese_number_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+(\d+)'
422
+ text = re.sub(chinese_number_pattern, r'\1\2', text)
423
+
424
+ # 4. 移除英文單詞內部的多餘空格(例如:"Nebula-X 跨次 元量" -> "Nebula-X 跨次元量")
425
+ # 但保留英文單詞之間的空格
426
+ # 匹配:非空格字符 + 空格 + 非空格字符(如果其中一個是中文,則移除空格)
427
+ mixed_space_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
428
+ text = re.sub(mixed_space_pattern, r'\1\2', text)
429
+
430
+ # 5. 移除多個連續空格(保留單個空格,用於英文單詞之間)
431
+ text = re.sub(r' +', ' ', text)
432
+
433
+ # 6. 清理行首行尾的空格(但保留換行符)
434
+ lines = text.split('\n')
435
+ cleaned_lines = [line.strip() for line in lines]
436
+ text = '\n'.join(cleaned_lines)
437
+
438
+ # 7. 移除多個連續的換行符(保留最多兩個,用於段落分隔)
439
+ text = re.sub(r'\n{3,}', '\n\n', text)
440
+
441
+ # 8. 修復可能的殘留問題:移除中文字符之間殘留的空格
442
+ # 再次檢查並移除中文字符之間的空格(處理可能遺漏的情況)
443
+ text = re.sub(r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])', r'\1\2', text)
444
+
445
+ return text
446
+
447
+ def load_from_file(self, file_path: str) -> Dict:
448
+ """
449
+ 從本地檔案載入文檔(支援 PDF, DOCX, TXT 等)
450
+
451
+ Args:
452
+ file_path: 檔案路徑
453
+
454
+ Returns:
455
+ 文檔字典,包含內容和元數據
456
+ """
457
+ file_path = Path(file_path)
458
+
459
+ if not file_path.exists():
460
+ raise FileNotFoundError(f"檔案不存在: {file_path}")
461
+
462
+ file_ext = file_path.suffix.lower()
463
+ file_name = file_path.stem
464
+ file_size = os.path.getsize(file_path)
465
+
466
+ # 根據檔案類型選擇不同的加載器
467
+ if file_ext == '.pdf':
468
+ try:
469
+ from langchain_community.document_loaders import PyPDFLoader
470
+ loader = PyPDFLoader(str(file_path))
471
+ pages = loader.load()
472
+ # 合併所有頁面
473
+ full_text = "\n\n".join([page.page_content for page in pages])
474
+ # 清理提取的文本(移除多餘空格)
475
+ full_text = self.clean_extracted_text(full_text)
476
+ except ImportError:
477
+ raise ImportError(
478
+ "需要安裝 pypdf 來處理 PDF 檔案: pip install pypdf"
479
+ )
480
+
481
+ elif file_ext in ['.docx', '.doc']:
482
+ try:
483
+ from langchain_community.document_loaders import Docx2txtLoader
484
+ loader = Docx2txtLoader(str(file_path))
485
+ pages = loader.load()
486
+ full_text = "\n\n".join([page.page_content for page in pages])
487
+ # 清理提取的文本(移除多餘空格)
488
+ full_text = self.clean_extracted_text(full_text)
489
+ except ImportError:
490
+ raise ImportError(
491
+ "需要安裝 docx2txt 來處理 DOCX 檔案: pip install docx2txt"
492
+ )
493
+
494
+ elif file_ext == '.txt':
495
+ # 嘗試不同的編碼
496
+ encodings = ['utf-8', 'gbk', 'big5', 'latin-1']
497
+ full_text = None
498
+ for encoding in encodings:
499
+ try:
500
+ with open(file_path, 'r', encoding=encoding) as f:
501
+ full_text = f.read()
502
+ break
503
+ except UnicodeDecodeError:
504
+ continue
505
+
506
+ if full_text is None:
507
+ raise ValueError(f"無法讀取檔案,嘗試的編碼都不適用: {encodings}")
508
+
509
+ else:
510
+ raise ValueError(
511
+ f"不支援的檔案類型: {file_ext}\n"
512
+ f"支援的格式: .pdf, .docx, .doc, .txt"
513
+ )
514
+
515
+ if not full_text or len(full_text.strip()) == 0:
516
+ raise ValueError(f"檔案為空或無法提取文字: {file_path}")
517
+
518
+ return {
519
+ "title": file_name,
520
+ "content": full_text,
521
+ "file_path": str(file_path),
522
+ "file_type": file_ext,
523
+ "file_size": file_size,
524
+ }
525
+
526
+ def process_file(self, file_path: str) -> List[Dict]:
527
+ """
528
+ 處理單個檔案,分割成 chunks
529
+
530
+ Args:
531
+ file_path: 檔案路徑
532
+
533
+ Returns:
534
+ 處理後的文檔 chunks 列表
535
+ """
536
+ # 載入檔案
537
+ file_doc = self.load_from_file(file_path)
538
+
539
+ # 分割文字(根據選擇的模式:字符分塊或語義分塊)
540
+ chunks = self.text_splitter.split_text(file_doc["content"])
541
+
542
+ # 後處理:過濾和合併太小的 chunks(僅語義分塊模式)
543
+ chunks = self._post_process_chunks(chunks)
544
+
545
+ if not chunks:
546
+ raise ValueError(f"檔案分割後沒有內容: {file_path}")
547
+
548
+ # 創建文檔 chunks
549
+ documents = []
550
+ for i, chunk in enumerate(chunks):
551
+ doc = {
552
+ "content": chunk,
553
+ "metadata": {
554
+ "title": file_doc["title"],
555
+ "file_path": file_doc["file_path"],
556
+ "file_type": file_doc["file_type"],
557
+ "file_size": file_doc["file_size"],
558
+ "chunk_index": i,
559
+ "total_chunks": len(chunks),
560
+ "chunking_method": "semantic" if self.use_semantic_chunking else "character"
561
+ }
562
+ }
563
+ documents.append(doc)
564
+
565
+ return documents
566
+
567
+ def process_files(self, file_paths: List[str]) -> List[Dict]:
568
+ """
569
+ 處理多個檔案
570
+
571
+ Args:
572
+ file_paths: 檔案路徑列表
573
+
574
+ Returns:
575
+ 所有檔案的文檔 chunks 列表
576
+ """
577
+ all_documents = []
578
+ for file_path in file_paths:
579
+ try:
580
+ print(f"處理檔案: {file_path}")
581
+ documents = self.process_file(file_path)
582
+ all_documents.extend(documents)
583
+ print(f" ✓ 創建了 {len(documents)} 個 chunks")
584
+ except Exception as e:
585
+ print(f" ✗ 處理檔案失敗: {file_path}")
586
+ print(f" 錯誤: {e}")
587
+ continue
588
+
589
+ return all_documents
590
+
src/hybrid_subquery_hyde_rag.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid Sub-query + HyDE RAG:融合 Sub-query Decomposition 和 HyDE
3
+ 結合兩種方法的優勢,提升檢索精度
4
+ """
5
+ from typing import List, Dict, Optional
6
+ from .retrievers.reranker import RAGPipeline
7
+ from .retrievers.vector_retriever import VectorRetriever
8
+ from .prompt_formatter import PromptFormatter
9
+ from .llm_integration import OllamaLLM
10
+ import hashlib
11
+ import time
12
+ import logging
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class HybridSubqueryHyDERAG:
19
+ """融合 Sub-query Decomposition 和 HyDE 的 RAG 系統"""
20
+
21
+ def __init__(
22
+ self,
23
+ rag_pipeline: RAGPipeline,
24
+ vector_retriever: VectorRetriever,
25
+ llm: OllamaLLM,
26
+ max_sub_queries: int = 3,
27
+ top_k_per_subquery: int = 5,
28
+ hypothetical_length: int = 200,
29
+ temperature_subquery: float = 0.3,
30
+ temperature_hyde: float = 0.7,
31
+ enable_parallel: bool = True
32
+ ):
33
+ """
34
+ 初始化融合 RAG
35
+
36
+ Args:
37
+ rag_pipeline: RAG 管線實例
38
+ vector_retriever: 向量檢索器
39
+ llm: LLM 實例
40
+ max_sub_queries: 最多生成的子問題數量
41
+ top_k_per_subquery: 每個子問題檢索的結果數量
42
+ hypothetical_length: 假設性文檔目標長度(字符數)
43
+ temperature_subquery: 生成子問題的溫度(較低,更穩定)
44
+ temperature_hyde: 生成假設性文檔的溫度(較高,更多專業術語)
45
+ enable_parallel: 是否並行處理
46
+ """
47
+ self.rag_pipeline = rag_pipeline
48
+ self.vector_retriever = vector_retriever
49
+ self.llm = llm
50
+ self.max_sub_queries = max_sub_queries
51
+ self.top_k_per_subquery = top_k_per_subquery
52
+ self.hypothetical_length = hypothetical_length
53
+ self.temperature_subquery = temperature_subquery
54
+ self.temperature_hyde = temperature_hyde
55
+ self.enable_parallel = enable_parallel
56
+
57
+ def _generate_sub_queries(self, question: str) -> List[str]:
58
+ """
59
+ 生成子問題(與 SubQueryDecompositionRAG 相同)
60
+
61
+ Args:
62
+ question: 原始問題
63
+
64
+ Returns:
65
+ 子問題列表
66
+ """
67
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
68
+
69
+ if is_chinese:
70
+ prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
71
+ 每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。
72
+
73
+ 原始問題: {question}
74
+
75
+ 子問題清單:"""
76
+ else:
77
+ prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
78
+ Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.
79
+
80
+ Original question: {question}
81
+
82
+ Sub-question list:"""
83
+
84
+ try:
85
+ response = self.llm.generate(
86
+ prompt=prompt,
87
+ temperature=self.temperature_subquery,
88
+ max_tokens=500
89
+ )
90
+
91
+ sub_queries = [
92
+ q.strip()
93
+ for q in response.strip().split("\n")
94
+ if q.strip() and not q.strip().startswith("#")
95
+ ]
96
+
97
+ # 移除編號前綴(如 "1. ", "1) " 等)
98
+ cleaned_queries = []
99
+ for q in sub_queries:
100
+ q = q.lstrip("0123456789. )")
101
+ q = q.strip()
102
+ if q:
103
+ cleaned_queries.append(q)
104
+
105
+ cleaned_queries = cleaned_queries[:self.max_sub_queries]
106
+
107
+ if not cleaned_queries:
108
+ logger.warning("⚠️ 未生成子問題,使用原始問題")
109
+ cleaned_queries = [question]
110
+
111
+ return cleaned_queries
112
+
113
+ except Exception as e:
114
+ logger.error(f"⚠️ 生成子問題時出錯: {e}")
115
+ return [question]
116
+
117
+ def _generate_hypothetical_document(self, sub_query: str) -> str:
118
+ """
119
+ 為子問題生成假設性文檔(與 HyDERAG 相同)
120
+
121
+ Args:
122
+ sub_query: 子問題
123
+
124
+ Returns:
125
+ 假設性文檔文本
126
+ """
127
+ is_chinese = PromptFormatter.detect_language(sub_query) == "zh"
128
+
129
+ if is_chinese:
130
+ prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
131
+ 這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
132
+ 請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。
133
+
134
+ 問題: {sub_query}
135
+
136
+ 專業技術內容:"""
137
+ else:
138
+ prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
139
+ This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
140
+ Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.
141
+
142
+ Question: {sub_query}
143
+
144
+ Professional technical content:"""
145
+
146
+ try:
147
+ hypothetical_doc = self.llm.generate(
148
+ prompt=prompt,
149
+ temperature=self.temperature_hyde,
150
+ max_tokens=500
151
+ )
152
+
153
+ hypothetical_doc = hypothetical_doc.strip()
154
+
155
+ if not hypothetical_doc:
156
+ logger.warning(f"⚠️ 子問題 '{sub_query}' 的假設性文檔為空,使用子問題本身")
157
+ return sub_query
158
+
159
+ logger.debug(f"✅ 為子問題生成假設性文檔(長度: {len(hypothetical_doc)} 字符)")
160
+ return hypothetical_doc
161
+
162
+ except Exception as e:
163
+ logger.error(f"⚠️ 生成假設性文檔時出錯: {e}")
164
+ return sub_query
165
+
166
+ def _get_doc_id(self, doc: Dict) -> str:
167
+ """
168
+ 生成文檔的唯一標識符
169
+
170
+ Args:
171
+ doc: 文檔字典
172
+
173
+ Returns:
174
+ 唯一 ID
175
+ """
176
+ metadata = doc.get("metadata", {})
177
+ content = doc.get("content", "")
178
+
179
+ if "arxiv_id" in metadata and "chunk_index" in metadata:
180
+ return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
181
+ elif "file_path" in metadata and "chunk_index" in metadata:
182
+ return f"{metadata['file_path']}_{metadata['chunk_index']}"
183
+ else:
184
+ content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
185
+ return f"doc_{content_hash}"
186
+
187
+ def _process_subquery_with_hyde(
188
+ self,
189
+ sub_query: str,
190
+ metadata_filter: Optional[Dict] = None
191
+ ) -> tuple:
192
+ """
193
+ 處理單個子問題:生成假設性文檔並檢索
194
+
195
+ Args:
196
+ sub_query: 子問題
197
+ metadata_filter: 可選的 metadata 過濾條件
198
+
199
+ Returns:
200
+ (檢索結果列表, 假設性文檔)
201
+ """
202
+ try:
203
+ # 生成假設性文檔
204
+ hypothetical_doc = self._generate_hypothetical_document(sub_query)
205
+
206
+ # 使用假設性文檔檢索
207
+ results = self.vector_retriever.retrieve(
208
+ query=hypothetical_doc, # 使用假設性文檔而不是子問題
209
+ top_k=self.top_k_per_subquery,
210
+ metadata_filter=metadata_filter
211
+ )
212
+
213
+ return results, hypothetical_doc
214
+
215
+ except Exception as e:
216
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
217
+ return [], ""
218
+
219
+ def query(
220
+ self,
221
+ question: str,
222
+ top_k: int = 5,
223
+ metadata_filter: Optional[Dict] = None,
224
+ return_sub_queries: bool = False,
225
+ return_hypothetical: bool = False
226
+ ) -> Dict:
227
+ """
228
+ 執行融合 RAG 檢索(不生成答案)
229
+
230
+ Args:
231
+ question: 原始問題
232
+ top_k: 返回前 k 個結果
233
+ metadata_filter: 可選的 metadata 過濾條件
234
+ return_sub_queries: 是否返回子問題列表
235
+ return_hypothetical: 是否返回假設性文檔字典(子問題 -> 假設性文檔)
236
+
237
+ Returns:
238
+ 包含檢索結果和統計資訊的字典
239
+ """
240
+ start_time = time.time()
241
+
242
+ # 第一步:生成子問題
243
+ logger.info(f"🔍 拆解問題: '{question}'")
244
+ sub_queries = self._generate_sub_queries(question)
245
+ logger.info(f"✅ 生成 {len(sub_queries)} 個子問題")
246
+
247
+ # 第二步:為每個子問題生成假設性文檔並檢索
248
+ logger.info(f"📚 為每個子問題生成假設性文檔並檢索...")
249
+ unique_docs = {}
250
+ hypothetical_docs = {}
251
+
252
+ if self.enable_parallel and len(sub_queries) > 1:
253
+ # 並行處理
254
+ logger.info(f"🔄 並行處理 {len(sub_queries)} 個子問題...")
255
+ with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
256
+ future_to_query = {
257
+ executor.submit(self._process_subquery_with_hyde, sq, metadata_filter): sq
258
+ for sq in sub_queries
259
+ }
260
+
261
+ for future in as_completed(future_to_query):
262
+ sub_query = future_to_query[future]
263
+ try:
264
+ results, hypo_doc = future.result()
265
+ hypothetical_docs[sub_query] = hypo_doc
266
+
267
+ logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(results)} 個結果")
268
+
269
+ for doc in results:
270
+ doc_id = self._get_doc_id(doc)
271
+ if doc_id not in unique_docs:
272
+ unique_docs[doc_id] = doc
273
+ else:
274
+ # 保留分數更高的
275
+ existing_score = unique_docs[doc_id].get('score', 0)
276
+ new_score = doc.get('score', 0)
277
+ if new_score > existing_score:
278
+ unique_docs[doc_id] = doc
279
+ except Exception as e:
280
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
281
+ else:
282
+ # 串行處理
283
+ logger.info(f"🔄 串行處理 {len(sub_queries)} 個子問題...")
284
+ for sub_query in sub_queries:
285
+ results, hypo_doc = self._process_subquery_with_hyde(sub_query, metadata_filter)
286
+ hypothetical_docs[sub_query] = hypo_doc
287
+
288
+ logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(results)} 個結果")
289
+
290
+ for doc in results:
291
+ doc_id = self._get_doc_id(doc)
292
+ if doc_id not in unique_docs:
293
+ unique_docs[doc_id] = doc
294
+ else:
295
+ existing_score = unique_docs[doc_id].get('score', 0)
296
+ new_score = doc.get('score', 0)
297
+ if new_score > existing_score:
298
+ unique_docs[doc_id] = doc
299
+
300
+ # 第三步:排序並返回前 top_k
301
+ result_list = list(unique_docs.values())
302
+ result_list.sort(key=lambda x: x.get('score', 0), reverse=True)
303
+ final_results = result_list[:top_k]
304
+
305
+ elapsed_time = time.time() - start_time
306
+ logger.info(f"✅ 找到 {len(final_results)} 個唯一文檔(去重後,總共 {len(result_list)} 個)")
307
+
308
+ return {
309
+ "results": final_results,
310
+ "total_docs_found": len(result_list),
311
+ "sub_queries": sub_queries if return_sub_queries else None,
312
+ "hypothetical_documents": hypothetical_docs if return_hypothetical else None,
313
+ "elapsed_time": elapsed_time
314
+ }
315
+
316
+ def generate_answer(
317
+ self,
318
+ question: str,
319
+ formatter: PromptFormatter,
320
+ top_k: int = 5,
321
+ metadata_filter: Optional[Dict] = None,
322
+ document_type: str = "general",
323
+ return_sub_queries: bool = False,
324
+ return_hypothetical: bool = False
325
+ ) -> Dict:
326
+ """
327
+ 完整的融合 RAG 流程:檢索 + 生成答案
328
+
329
+ Args:
330
+ question: 原始問題
331
+ formatter: Prompt 格式化器
332
+ top_k: 用於生成答案的文檔數量
333
+ metadata_filter: 可選的 metadata 過濾條件
334
+ document_type: 文檔類型 ("paper", "cv", "general")
335
+ return_sub_queries: 是否返回子問題列表
336
+ return_hypothetical: 是否返回假設性文檔字典
337
+
338
+ Returns:
339
+ 包含檢索結果、生成的答案和統計資訊的字典
340
+ """
341
+ start_time = time.time()
342
+
343
+ # 檢索
344
+ retrieval_result = self.query(
345
+ question=question,
346
+ top_k=top_k,
347
+ metadata_filter=metadata_filter,
348
+ return_sub_queries=return_sub_queries,
349
+ return_hypothetical=return_hypothetical
350
+ )
351
+
352
+ if not retrieval_result["results"]:
353
+ return {
354
+ **retrieval_result,
355
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
356
+ "formatted_context": None,
357
+ "answer_time": 0.0,
358
+ "total_time": retrieval_result["elapsed_time"]
359
+ }
360
+
361
+ # 格式化上下文
362
+ formatted_context = formatter.format_context(
363
+ retrieval_result["results"],
364
+ document_type=document_type
365
+ )
366
+
367
+ # 創建 prompt(使用原始問題)
368
+ prompt = formatter.create_prompt(
369
+ question,
370
+ formatted_context,
371
+ document_type=document_type
372
+ )
373
+
374
+ # 生成回答
375
+ logger.info("🤖 生成回答中...")
376
+ answer_start = time.time()
377
+ try:
378
+ answer = self.llm.generate(
379
+ prompt=prompt,
380
+ temperature=0.7,
381
+ max_tokens=2048
382
+ )
383
+ answer_time = time.time() - answer_start
384
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
385
+ except Exception as e:
386
+ logger.error(f"❌ 生成回答時出錯: {e}")
387
+ answer = f"生成回答時出錯: {e}"
388
+ answer_time = time.time() - answer_start
389
+
390
+ total_time = time.time() - start_time
391
+
392
+ return {
393
+ **retrieval_result,
394
+ "answer": answer,
395
+ "formatted_context": formatted_context,
396
+ "answer_time": answer_time,
397
+ "total_time": total_time
398
+ }
399
+
src/hyde_rag.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HyDE (Hypothetical Document Embeddings) RAG:使用假設性文檔改善檢索
3
+ """
4
+ from typing import List, Dict, Optional
5
+ from .retrievers.reranker import RAGPipeline
6
+ from .retrievers.vector_retriever import VectorRetriever
7
+ from .prompt_formatter import PromptFormatter
8
+ from .llm_integration import OllamaLLM
9
+ import time
10
+ import logging
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class HyDERAG:
16
+ """使用 HyDE (Hypothetical Document Embeddings) 的 RAG 系統"""
17
+
18
+ def __init__(
19
+ self,
20
+ rag_pipeline: RAGPipeline,
21
+ vector_retriever: VectorRetriever,
22
+ llm: OllamaLLM,
23
+ hypothetical_length: int = 200,
24
+ temperature: float = 0.7
25
+ ):
26
+ """
27
+ 初始化 HyDE RAG
28
+
29
+ Args:
30
+ rag_pipeline: RAG 管線實例(用於最終答案生成)
31
+ vector_retriever: 向量檢索器(用於基於假設性文檔的檢索)
32
+ llm: LLM 實例(用於生成假設性文檔)
33
+ hypothetical_length: 假設性文檔的目標長度(字符數)
34
+ temperature: 生成假設性文檔時的溫度參數(建議 0.7,以獲得更多專業術語)
35
+ """
36
+ self.rag_pipeline = rag_pipeline
37
+ self.vector_retriever = vector_retriever
38
+ self.llm = llm
39
+ self.hypothetical_length = hypothetical_length
40
+ self.temperature = temperature
41
+
42
+ def _generate_hypothetical_document(self, question: str) -> str:
43
+ """
44
+ 生成假設性文檔(Hypothetical Document)
45
+
46
+ Args:
47
+ question: 用戶問題
48
+
49
+ Returns:
50
+ 假設性文檔文本
51
+ """
52
+ # 檢測語言
53
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
54
+
55
+ if is_chinese:
56
+ prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
57
+ 這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
58
+ 請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。
59
+
60
+ 問題: {question}
61
+
62
+ 專業技術內容:"""
63
+ else:
64
+ prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
65
+ This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
66
+ Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.
67
+
68
+ Question: {question}
69
+
70
+ Professional technical content:"""
71
+
72
+ try:
73
+ hypothetical_doc = self.llm.generate(
74
+ prompt=prompt,
75
+ temperature=self.temperature, # 較高的溫度以獲得更多專業術語
76
+ max_tokens=500
77
+ )
78
+
79
+ # 清理輸出
80
+ hypothetical_doc = hypothetical_doc.strip()
81
+
82
+ if not hypothetical_doc:
83
+ logger.warning("⚠️ 生成的假設性文檔為空,使用原始問題")
84
+ return question
85
+
86
+ logger.info(f"✅ 生成假設性文檔(長度: {len(hypothetical_doc)} 字符)")
87
+ return hypothetical_doc
88
+
89
+ except Exception as e:
90
+ logger.error(f"⚠️ 生成假設性文檔時出錯: {e}")
91
+ # 回退到使用原始問題
92
+ return question
93
+
94
+ def query(
95
+ self,
96
+ question: str,
97
+ top_k: int = 5,
98
+ metadata_filter: Optional[Dict] = None,
99
+ return_hypothetical: bool = False
100
+ ) -> Dict:
101
+ """
102
+ 執行 HyDE 檢索(不生成答案)
103
+
104
+ Args:
105
+ question: 原始問題
106
+ top_k: 返回前 k 個結果
107
+ metadata_filter: 可選的 metadata 過濾條件
108
+ return_hypothetical: 是否在結果中包含假設性文檔
109
+
110
+ Returns:
111
+ 包含檢索結果和統計資訊的字典
112
+ """
113
+ start_time = time.time()
114
+
115
+ # 第一步:生成假設性文檔
116
+ logger.info(f"🔍 生成假設性文檔: '{question}'")
117
+ hypothetical_doc = self._generate_hypothetical_document(question)
118
+
119
+ # 第二步:使用假設性文檔進行檢索
120
+ logger.info(f"📚 使用假設性文檔進行檢索...")
121
+ results = self.vector_retriever.retrieve(
122
+ query=hypothetical_doc, # 使用假設性文檔而不是原始問題
123
+ top_k=top_k,
124
+ metadata_filter=metadata_filter
125
+ )
126
+
127
+ elapsed_time = time.time() - start_time
128
+ logger.info(f"✅ 找到 {len(results)} 個結果(耗時: {elapsed_time:.2f}s)")
129
+
130
+ result = {
131
+ "results": results,
132
+ "total_docs_found": len(results),
133
+ "hypothetical_document": hypothetical_doc if return_hypothetical else None,
134
+ "elapsed_time": elapsed_time
135
+ }
136
+
137
+ return result
138
+
139
+ def generate_answer(
140
+ self,
141
+ question: str,
142
+ formatter: PromptFormatter,
143
+ top_k: int = 5,
144
+ metadata_filter: Optional[Dict] = None,
145
+ document_type: str = "general",
146
+ return_hypothetical: bool = False
147
+ ) -> Dict:
148
+ """
149
+ 完整的 HyDE RAG 流程:生成假設性文檔 -> 檢索 -> 生成答案
150
+
151
+ Args:
152
+ question: 原始問題
153
+ formatter: Prompt 格式化器
154
+ top_k: 用於生成答案的文檔數量
155
+ metadata_filter: 可選的 metadata 過濾條件
156
+ document_type: 文檔類型 ("paper", "cv", "general")
157
+ return_hypothetical: 是否在結果中包含假設性文檔
158
+
159
+ Returns:
160
+ 包含檢索結果、生成的答案和統計資訊的字典
161
+ """
162
+ start_time = time.time()
163
+
164
+ # 第一步:生成假設性文檔
165
+ logger.info(f"🔍 生成假設性文檔: '{question}'")
166
+ hypothetical_start = time.time()
167
+ hypothetical_doc = self._generate_hypothetical_document(question)
168
+ hypothetical_time = time.time() - hypothetical_start
169
+
170
+ # 第二步:使用假設性文檔進行檢索
171
+ logger.info(f"📚 使用假設性文檔進行檢索...")
172
+ retrieval_start = time.time()
173
+ results = self.vector_retriever.retrieve(
174
+ query=hypothetical_doc, # 使用假設性文檔而不是原始問題
175
+ top_k=top_k,
176
+ metadata_filter=metadata_filter
177
+ )
178
+ retrieval_time = time.time() - retrieval_start
179
+
180
+ if not results:
181
+ return {
182
+ "results": [],
183
+ "total_docs_found": 0,
184
+ "hypothetical_document": hypothetical_doc if return_hypothetical else None,
185
+ "elapsed_time": retrieval_time + hypothetical_time,
186
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
187
+ "formatted_context": None,
188
+ "answer_time": 0.0,
189
+ "total_time": retrieval_time + hypothetical_time
190
+ }
191
+
192
+ # 第三步:格式化上下文
193
+ formatted_context = formatter.format_context(
194
+ results,
195
+ document_type=document_type
196
+ )
197
+
198
+ # 第四步:創建 prompt(使用原始問題,而不是假設性文檔)
199
+ prompt = formatter.create_prompt(
200
+ question, # 使用原始問題生成答案
201
+ formatted_context,
202
+ document_type=document_type
203
+ )
204
+
205
+ # 第五步:生成回答
206
+ logger.info("🤖 生成回答中...")
207
+ answer_start = time.time()
208
+ try:
209
+ answer = self.llm.generate(
210
+ prompt=prompt,
211
+ temperature=0.7,
212
+ max_tokens=2048
213
+ )
214
+ answer_time = time.time() - answer_start
215
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
216
+ except Exception as e:
217
+ logger.error(f"❌ 生成回答時出錯: {e}")
218
+ answer = f"生成回答時出錯: {e}"
219
+ answer_time = time.time() - answer_start
220
+
221
+ total_time = time.time() - start_time
222
+
223
+ return {
224
+ "results": results,
225
+ "total_docs_found": len(results),
226
+ "hypothetical_document": hypothetical_doc if return_hypothetical else None,
227
+ "elapsed_time": retrieval_time + hypothetical_time,
228
+ "hypothetical_time": hypothetical_time,
229
+ "retrieval_time": retrieval_time,
230
+ "answer": answer,
231
+ "formatted_context": formatted_context,
232
+ "answer_time": answer_time,
233
+ "total_time": total_time
234
+ }
235
+
src/llm_integration.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM 集成模組:使用 Ollama 進行本地 LLM 推理
3
+ """
4
+ from typing import Optional, Dict, List
5
+ import logging
6
+ import requests
7
+ import json
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class OllamaLLM:
13
+ """使用 Ollama 進行本地 LLM 推理"""
14
+
15
+ # 適合 16GB MacBook Air 的模型推薦
16
+ RECOMMENDED_MODELS = {
17
+ "deepseek-r1:7b": {
18
+ "name": "deepseek-r1:7b",
19
+ "description": "DeepSeek R1 7B - 大模型,高質量",
20
+ "memory_required": "~8GB",
21
+ "quality": "優秀"
22
+ },
23
+ "llama3.2:3b": {
24
+ "name": "llama3.2:3b",
25
+ "description": "Meta Llama 3.2 3B - 輕量級,適合 16GB 內存",
26
+ "memory_required": "~4GB",
27
+ "quality": "良好"
28
+ },
29
+ "llama3.2:1b": {
30
+ "name": "llama3.2:1b",
31
+ "description": "Meta Llama 3.2 1B - 極輕量級,快速響應",
32
+ "memory_required": "~2GB",
33
+ "quality": "基礎"
34
+ },
35
+ "phi3:mini": {
36
+ "name": "phi3:mini",
37
+ "description": "Microsoft Phi-3 Mini - 小模型,高質量",
38
+ "memory_required": "~3GB",
39
+ "quality": "良好"
40
+ },
41
+ "gemma:2b": {
42
+ "name": "gemma:2b",
43
+ "description": "Google Gemma 2B - 輕量級,開源",
44
+ "memory_required": "~3GB",
45
+ "quality": "良好"
46
+ },
47
+ "mistral:7b": {
48
+ "name": "mistral:7b",
49
+ "description": "Mistral 7B - 較大但質量高(如果內存足夠)",
50
+ "memory_required": "~8GB",
51
+ "quality": "優秀"
52
+ }
53
+ }
54
+
55
+ def __init__(
56
+ self,
57
+ model_name: str = "llama3.2:3b",
58
+ base_url: str = "http://localhost:11434",
59
+ timeout: int = 120
60
+ ):
61
+ """
62
+ 初始化 Ollama LLM
63
+
64
+ Args:
65
+ model_name: Ollama 模型名稱(預設: llama3.2:3b)
66
+ base_url: Ollama API 基礎 URL
67
+ timeout: 請求超時時間(秒)
68
+ """
69
+ self.model_name = model_name
70
+ self.base_url = base_url.rstrip('/')
71
+ self.timeout = timeout
72
+ self.api_url = f"{self.base_url}/api"
73
+
74
+ # 檢查模型是否在推薦列表中
75
+ if model_name not in self.RECOMMENDED_MODELS:
76
+ logger.warning(
77
+ f"⚠️ 模型 '{model_name}' 不在推薦列表中。"
78
+ f"推薦的模型: {', '.join(self.RECOMMENDED_MODELS.keys())}"
79
+ )
80
+
81
+ logger.info(f"✅ Ollama LLM 初始化完成 (模型: {model_name})")
82
+
83
+ def _check_ollama_connection(self) -> bool:
84
+ """
85
+ 檢查 Ollama 服務是否可用
86
+
87
+ Returns:
88
+ 是否連接成功
89
+ """
90
+ try:
91
+ response = requests.get(f"{self.base_url}/api/tags", timeout=5)
92
+ return response.status_code == 200
93
+ except Exception as e:
94
+ logger.error(f"❌ 無法連接到 Ollama: {e}")
95
+ logger.error(f" 請確保 Ollama 正在運行: ollama serve")
96
+ return False
97
+
98
+ def _check_model_available(self) -> bool:
99
+ """
100
+ 檢查模型是否已下載
101
+
102
+ Returns:
103
+ 模型是否可用
104
+ """
105
+ try:
106
+ response = requests.get(f"{self.base_url}/api/tags", timeout=5)
107
+ if response.status_code == 200:
108
+ models = response.json().get('models', [])
109
+ model_names = [m.get('name', '') for m in models]
110
+ return any(self.model_name in name for name in model_names)
111
+ return False
112
+ except Exception as e:
113
+ logger.error(f"❌ 檢查模型時出錯: {e}")
114
+ return False
115
+
116
+ def generate(
117
+ self,
118
+ prompt: str,
119
+ temperature: float = 0.7,
120
+ max_tokens: Optional[int] = None,
121
+ stream: bool = False
122
+ ) -> str:
123
+ """
124
+ 生成回答
125
+
126
+ Args:
127
+ prompt: 輸入 prompt
128
+ temperature: 溫度參數(0.0-1.0),控制隨機性
129
+ max_tokens: 最大生成 token 數(None 表示使用模型預設)
130
+ stream: 是否使用流式輸出
131
+
132
+ Returns:
133
+ 生成的回答
134
+ """
135
+ # 檢查連接
136
+ if not self._check_ollama_connection():
137
+ raise ConnectionError(
138
+ f"無法連接到 Ollama 服務 ({self.base_url})\n"
139
+ f"請確保 Ollama 正在運行:\n"
140
+ f" 1. 安裝 Ollama: https://ollama.ai\n"
141
+ f" 2. 啟動服務: ollama serve\n"
142
+ f" 3. 下載模型: ollama pull {self.model_name}"
143
+ )
144
+
145
+ # 檢查模型
146
+ if not self._check_model_available():
147
+ logger.warning(
148
+ f"⚠️ 模型 '{self.model_name}' 可能未下載。"
149
+ f"請運行: ollama pull {self.model_name}"
150
+ )
151
+
152
+ # 準備請求參數
153
+ payload = {
154
+ "model": self.model_name,
155
+ "prompt": prompt,
156
+ "stream": stream,
157
+ "options": {
158
+ "temperature": temperature,
159
+ }
160
+ }
161
+
162
+ if max_tokens:
163
+ payload["options"]["num_predict"] = max_tokens
164
+
165
+ try:
166
+ # 發送請求
167
+ response = requests.post(
168
+ f"{self.api_url}/generate",
169
+ json=payload,
170
+ timeout=self.timeout,
171
+ stream=stream
172
+ )
173
+
174
+ if response.status_code != 200:
175
+ error_msg = response.text
176
+ raise RuntimeError(f"Ollama API 錯誤: {error_msg}")
177
+
178
+ if stream:
179
+ # 流式處理
180
+ full_response = ""
181
+ for line in response.iter_lines():
182
+ if line:
183
+ try:
184
+ data = json.loads(line)
185
+ if 'response' in data:
186
+ chunk = data['response']
187
+ full_response += chunk
188
+ print(chunk, end='', flush=True)
189
+ if data.get('done', False):
190
+ break
191
+ except json.JSONDecodeError:
192
+ continue
193
+ print() # 換行
194
+ return full_response
195
+ else:
196
+ # 非流式處理
197
+ data = response.json()
198
+ return data.get('response', '')
199
+
200
+ except requests.exceptions.Timeout:
201
+ raise TimeoutError(
202
+ f"請求超時({self.timeout}秒)。"
203
+ f"可以嘗試增加 timeout 或使用更小的模型。"
204
+ )
205
+ except requests.exceptions.ConnectionError:
206
+ raise ConnectionError(
207
+ f"無法連接到 Ollama 服務。"
208
+ f"請確保 Ollama 正在運行:ollama serve"
209
+ )
210
+ except Exception as e:
211
+ logger.error(f"❌ 生成回答時出錯: {e}")
212
+ raise
213
+
214
+ def list_available_models(self) -> List[str]:
215
+ """
216
+ 列出本地可用的模型
217
+
218
+ Returns:
219
+ 可用模型名稱列表
220
+ """
221
+ try:
222
+ response = requests.get(f"{self.base_url}/api/tags", timeout=5)
223
+ if response.status_code == 200:
224
+ models = response.json().get('models', [])
225
+ return [m.get('name', '') for m in models]
226
+ return []
227
+ except Exception as e:
228
+ logger.error(f"❌ 獲取模型列表時出錯: {e}")
229
+ return []
230
+
231
+ @classmethod
232
+ def print_recommended_models(cls):
233
+ """打印推薦的模型列表"""
234
+ print("\n" + "="*60)
235
+ print("適合 16GB MacBook Air 的 Ollama 模型推薦")
236
+ print("="*60)
237
+ print()
238
+
239
+ for model_key, info in cls.RECOMMENDED_MODELS.items():
240
+ print(f"📦 {info['name']}")
241
+ print(f" 描述: {info['description']}")
242
+ print(f" 內存需求: {info['memory_required']}")
243
+ print(f" 質量: {info['quality']}")
244
+ print(f" 下載命令: ollama pull {info['name']}")
245
+ print()
246
+
src/prompt_formatter.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt 格式化模組:將檢索結果格式化為 LLM 可讀的上下文
3
+ """
4
+ from typing import List, Dict, Optional
5
+ import re
6
+
7
+
8
+ class PromptFormatter:
9
+ """格式化檢索結果供 LLM 使用"""
10
+
11
+ def __init__(
12
+ self,
13
+ include_metadata: bool = True,
14
+ format_style: str = "detailed",
15
+ max_context_length: Optional[int] = None,
16
+ auto_detect_language: bool = True
17
+ ):
18
+ """
19
+ 初始化 Prompt 格式化器
20
+
21
+ Args:
22
+ include_metadata: 是否包含來源資訊
23
+ format_style: 格式風格 ("detailed", "simple", "minimal")
24
+ max_context_length: 最大上下文長度(字符數),None 表示不限制
25
+ auto_detect_language: 是否自動檢測語言並相應調整回答語言
26
+ """
27
+ self.include_metadata = include_metadata
28
+ self.format_style = format_style
29
+ self.max_context_length = max_context_length
30
+ self.auto_detect_language = auto_detect_language
31
+
32
+ @staticmethod
33
+ def detect_language(text: str) -> str:
34
+ """
35
+ 檢測文本的主要語言
36
+
37
+ Args:
38
+ text: 輸入文本
39
+
40
+ Returns:
41
+ "zh" 表示中文,"en" 表示英文
42
+ """
43
+ # 檢查是否包含中文字符(CJK 統一表意文字範圍)
44
+ chinese_pattern = re.compile(r'[\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff]')
45
+ chinese_chars = len(chinese_pattern.findall(text))
46
+
47
+ # 計算中文字符比例
48
+ total_chars = len([c for c in text if c.isalnum() or c.isspace()])
49
+
50
+ if total_chars == 0:
51
+ return "en" # 預設英文
52
+
53
+ chinese_ratio = chinese_chars / total_chars if total_chars > 0 else 0
54
+
55
+ # 如果中文字符比例超過 20%,認為是中文
56
+ if chinese_ratio > 0.2:
57
+ return "zh"
58
+ else:
59
+ return "en"
60
+
61
+ def get_system_prompt(self, language: str = "zh", document_type: str = "general") -> str:
62
+ """
63
+ 根據語言和文檔類型獲取系統提示詞
64
+
65
+ Args:
66
+ language: 語言代碼 ("zh" 或 "en")
67
+ document_type: 文檔類型 ("paper", "cv", "general")
68
+ "paper": 學術論文
69
+ "cv": 履歷/履歷
70
+ "general": 通用文檔(預設)
71
+
72
+ Returns:
73
+ 系統提示詞字符串
74
+ """
75
+ if language == "zh":
76
+ if document_type == "paper":
77
+ return (
78
+ "你是一個專業的 AI 研究助手,專門回答關於機器學習、"
79
+ "深度學習和自然語言處理的問題。\n\n"
80
+ "請基於以下提供的學術論文片段來回答用戶的問題。"
81
+ "每個片段都標註了來源論文的資訊。\n\n"
82
+ "回答要求:\n"
83
+ "1. 基於提供的上下文回答問題\n"
84
+ "2. 如果上下文不足以回答,請明確說明\n"
85
+ "3. 在回答中引用具體的論文來源(使用 arXiv ID)\n"
86
+ "4. 如果不同論文有不同觀點,請分別說明\n"
87
+ "5. 保持回答簡潔、準確、專業\n"
88
+ "6. **重要:請使用與用戶問題相同的語言回答**\n"
89
+ )
90
+ elif document_type == "cv":
91
+ return (
92
+ "你是一個專業的 AI 助手,專門幫助分析和介紹簡歷(CV)內容。\n\n"
93
+ "請基於以下提供的文檔片段來回答用戶的問題。"
94
+ "這些片段來自一份簡歷或履歷表。\n\n"
95
+ "回答要求:\n"
96
+ "1. 基於提供的上下文回答問題\n"
97
+ "2. 如果上下文不足以回答,請明確說明\n"
98
+ "3. 在回答中引用具體的文檔內容\n"
99
+ "4. 保持回答簡潔、準確、專業\n"
100
+ "5. **重要:請使用與用戶問題相同的語言回答**\n"
101
+ "6. **請理解:這些片段就是簡歷的內容,請直接基於這些內容回答問題**\n"
102
+ )
103
+ else: # general
104
+ return (
105
+ "你是一個專業的 AI 助手。\n\n"
106
+ "請基於以下提供的文檔片段來回答用戶的問題。"
107
+ "每個片段都標註了來源資訊。\n\n"
108
+ "回答要求:\n"
109
+ "1. 基於提供的上下文回答問題\n"
110
+ "2. 如果上下文不足以回答,請明確說明\n"
111
+ "3. 在回答中引用具體的文檔內容\n"
112
+ "4. 保持回答簡潔、準確、專業\n"
113
+ "5. **重要:請使用與用戶問題相同的語言回答**\n"
114
+ )
115
+ else: # English
116
+ if document_type == "paper":
117
+ return (
118
+ "You are a professional AI research assistant specializing in "
119
+ "machine learning, deep learning, and natural language processing.\n\n"
120
+ "Please answer the user's question based on the provided academic paper excerpts. "
121
+ "Each excerpt is labeled with source paper information.\n\n"
122
+ "Answer requirements:\n"
123
+ "1. Answer the question based on the provided context\n"
124
+ "2. If the context is insufficient, clearly state so\n"
125
+ "3. Cite specific paper sources in your answer (using arXiv ID)\n"
126
+ "4. If different papers have different viewpoints, explain them separately\n"
127
+ "5. Keep answers concise, accurate, and professional\n"
128
+ "6. **Important: Please answer in the same language as the user's question**\n"
129
+ )
130
+ elif document_type == "cv":
131
+ return (
132
+ "You are a professional AI assistant specializing in analyzing and introducing CV (Curriculum Vitae) content.\n\n"
133
+ "Please answer the user's question based on the provided document excerpts. "
134
+ "These excerpts are from a CV or resume.\n\n"
135
+ "Answer requirements:\n"
136
+ "1. Answer the question based on the provided context\n"
137
+ "2. If the context is insufficient, clearly state so\n"
138
+ "3. Cite specific document content in your answer\n"
139
+ "4. Keep answers concise, accurate, and professional\n"
140
+ "5. **Important: Please answer in the same language as the user's question**\n"
141
+ "6. **Please understand: These excerpts ARE the CV content. Please answer directly based on this content.**\n"
142
+ )
143
+ else: # general
144
+ return (
145
+ "You are a professional AI assistant.\n\n"
146
+ "Please answer the user's question based on the provided document excerpts. "
147
+ "Each excerpt is labeled with source information.\n\n"
148
+ "Answer requirements:\n"
149
+ "1. Answer the question based on the provided context\n"
150
+ "2. If the context is insufficient, clearly state so\n"
151
+ "3. Cite specific document content in your answer\n"
152
+ "4. Keep answers concise, accurate, and professional\n"
153
+ "5. **Important: Please answer in the same language as the user's question**\n"
154
+ )
155
+
156
+ def format_context(
157
+ self,
158
+ results: List[Dict],
159
+ include_metadata: Optional[bool] = None,
160
+ format_style: Optional[str] = None,
161
+ document_type: str = "general"
162
+ ) -> str:
163
+ """
164
+ 格式化檢索結果為 LLM 可讀的上下文
165
+
166
+ Args:
167
+ results: 檢索結果列表
168
+ include_metadata: 是否包含來源資訊(覆蓋初始化參數)
169
+ format_style: 格式風格(覆蓋初始化參數)
170
+ document_type: 文檔類型 ("paper", "cv", "general"),用於調整格式
171
+
172
+ Returns:
173
+ 格式化後的上下文字符串
174
+ """
175
+ if include_metadata is None:
176
+ include_metadata = self.include_metadata
177
+ if format_style is None:
178
+ format_style = self.format_style
179
+
180
+ if not results:
181
+ # 根據格式風格選擇語言
182
+ if format_style == "detailed" or format_style == "simple":
183
+ return "(未找到相關文檔片段)"
184
+ else:
185
+ return "(No relevant excerpts found)"
186
+
187
+ formatted_parts = []
188
+
189
+ for i, result in enumerate(results, 1):
190
+ content = result.get("content", "")
191
+ metadata = result.get("metadata", {})
192
+
193
+ if not include_metadata:
194
+ # 不包含來源資訊,直接使用內容
195
+ formatted_parts.append(f"{content}\n")
196
+ elif format_style == "detailed":
197
+ # 詳細格式:根據文檔類型調整顯示資訊
198
+ if document_type == "cv":
199
+ # CV 格式:顯示檔案名和路徑
200
+ source_info = (
201
+ f"[來源 {i}]\n"
202
+ f"檔案標題: {metadata.get('title', 'N/A')}\n"
203
+ )
204
+ if 'file_path' in metadata:
205
+ source_info += f"檔案路徑: {metadata.get('file_path', 'N/A')}\n"
206
+ if 'file_type' in metadata:
207
+ source_info += f"檔案類型: {metadata.get('file_type', 'N/A')}\n"
208
+ elif document_type == "paper":
209
+ # 論文格式:顯示論文資訊
210
+ authors = metadata.get('authors', [])
211
+ if isinstance(authors, str):
212
+ authors_str = authors
213
+ elif isinstance(authors, list):
214
+ authors_str = ', '.join(authors[:3]) # 最多顯示 3 個作者
215
+ if len(authors) > 3:
216
+ authors_str += f" 等 {len(authors)} 位作者"
217
+ else:
218
+ authors_str = 'N/A'
219
+
220
+ source_info = (
221
+ f"[來源 {i}]\n"
222
+ f"論文標題: {metadata.get('title', 'N/A')}\n"
223
+ f"arXiv ID: {metadata.get('arxiv_id', 'N/A')}\n"
224
+ f"作者: {authors_str}\n"
225
+ f"發布日期: {metadata.get('published', 'N/A')}\n"
226
+ )
227
+ else:
228
+ # 通用格式:顯示可用的資訊
229
+ source_info = f"[來源 {i}]\n"
230
+ if 'title' in metadata:
231
+ source_info += f"標題: {metadata.get('title', 'N/A')}\n"
232
+ if 'file_path' in metadata:
233
+ source_info += f"檔案: {metadata.get('file_path', 'N/A')}\n"
234
+ if 'arxiv_id' in metadata:
235
+ source_info += f"arXiv ID: {metadata.get('arxiv_id', 'N/A')}\n"
236
+
237
+ # 添加相關性分數(如果有的話)
238
+ rerank_score = result.get('rerank_score')
239
+ hybrid_score = result.get('hybrid_score')
240
+ if rerank_score is not None:
241
+ source_info += f"相關性分數: {rerank_score:.4f}\n"
242
+ elif hybrid_score is not None:
243
+ source_info += f"相關性分數: {hybrid_score:.4f}\n"
244
+
245
+ source_info += f"---\n{content}\n"
246
+ formatted_parts.append(source_info)
247
+
248
+ elif format_style == "simple":
249
+ # 簡單格式:只包含關鍵資訊
250
+ title = metadata.get('title', 'N/A')
251
+ if document_type == "paper" and 'arxiv_id' in metadata:
252
+ arxiv_id = metadata.get('arxiv_id', 'N/A')
253
+ source_info = (
254
+ f"[來源 {i}: {title} "
255
+ f"(arXiv:{arxiv_id})]\n"
256
+ f"{content}\n"
257
+ )
258
+ elif document_type == "cv" and 'file_path' in metadata:
259
+ file_path = metadata.get('file_path', 'N/A')
260
+ source_info = (
261
+ f"[來源 {i}: {title} "
262
+ f"({file_path})]\n"
263
+ f"{content}\n"
264
+ )
265
+ else:
266
+ source_info = (
267
+ f"[來源 {i}: {title}]\n"
268
+ f"{content}\n"
269
+ )
270
+ formatted_parts.append(source_info)
271
+ else: # minimal
272
+ # 最小格式:只標註來源
273
+ if document_type == "paper" and 'arxiv_id' in metadata:
274
+ arxiv_id = metadata.get('arxiv_id', 'N/A')
275
+ source_info = (
276
+ f"[arXiv:{arxiv_id}]\n"
277
+ f"{content}\n"
278
+ )
279
+ elif 'title' in metadata:
280
+ title = metadata.get('title', 'N/A')
281
+ source_info = (
282
+ f"[{title}]\n"
283
+ f"{content}\n"
284
+ )
285
+ else:
286
+ source_info = (
287
+ f"[來源 {i}]\n"
288
+ f"{content}\n"
289
+ )
290
+ formatted_parts.append(source_info)
291
+
292
+ formatted_text = "\n" + "="*60 + "\n".join(formatted_parts)
293
+
294
+ # 如果設置了最大長度,進行截斷
295
+ if self.max_context_length and len(formatted_text) > self.max_context_length:
296
+ # 從後往前截斷,保留格式
297
+ formatted_text = formatted_text[:self.max_context_length]
298
+ # 確保最後一個來源資訊完整
299
+ last_separator = formatted_text.rfind("="*60)
300
+ if last_separator > 0:
301
+ formatted_text = formatted_text[:last_separator] + "\n(內容已截斷...)"
302
+
303
+ return formatted_text
304
+
305
+ def create_prompt(
306
+ self,
307
+ query: str,
308
+ context: str,
309
+ system_prompt: Optional[str] = None,
310
+ document_type: str = "general"
311
+ ) -> str:
312
+ """
313
+ 創建完整的 LLM prompt
314
+
315
+ Args:
316
+ query: 用戶查詢
317
+ context: 格式化後的上下文
318
+ system_prompt: 可選的系統提示詞(如果為 None,會根據語言和文檔類型自動選擇)
319
+ document_type: 文檔類型 ("paper", "cv", "general")
320
+
321
+ Returns:
322
+ 完整的 prompt 字符串
323
+ """
324
+ # 自動檢測語言並選擇相應的系統提示詞
325
+ if system_prompt is None and self.auto_detect_language:
326
+ detected_language = self.detect_language(query)
327
+ system_prompt = self.get_system_prompt(detected_language, document_type)
328
+ elif system_prompt is None:
329
+ # 如果禁用自動檢測,使用中文作為預設
330
+ system_prompt = self.get_system_prompt("zh", document_type)
331
+
332
+ # 根據檢測到的語言選擇提示詞格式
333
+ detected_language = self.detect_language(query) if self.auto_detect_language else "zh"
334
+
335
+ # 根據文檔類型選擇不同的提示詞結尾
336
+ if document_type == "paper":
337
+ if detected_language == "zh":
338
+ ending = "## 請基於上述文獻片段回答問題,並在回答中引用具體的論文來源。"
339
+ else:
340
+ ending = "## Please answer the question based on the above document excerpts and cite specific paper sources in your answer."
341
+ else:
342
+ if detected_language == "zh":
343
+ ending = "## 請基於上述文檔片段回答問題,並在回答中引用具體的文檔內容。"
344
+ else:
345
+ ending = "## Please answer the question based on the above document excerpts and cite specific document content in your answer."
346
+
347
+ if detected_language == "zh":
348
+ prompt = f"""{system_prompt}
349
+
350
+ ## 相關文檔片段:
351
+
352
+ {context}
353
+
354
+ ## 用戶問題:
355
+
356
+ {query}
357
+
358
+ {ending}"""
359
+ else: # English
360
+ prompt = f"""{system_prompt}
361
+
362
+ ## Relevant Document Excerpts:
363
+
364
+ {context}
365
+
366
+ ## User Question:
367
+
368
+ {query}
369
+
370
+ {ending}"""
371
+
372
+ return prompt
373
+
374
+ def format_for_llm(
375
+ self,
376
+ query: str,
377
+ results: List[Dict],
378
+ system_prompt: Optional[str] = None,
379
+ document_type: str = "general"
380
+ ) -> str:
381
+ """
382
+ 一站式方法:格式化檢索結果並創建完整的 prompt
383
+
384
+ Args:
385
+ query: 用戶查詢
386
+ results: 檢索結果列表
387
+ system_prompt: 可選的系統提示詞
388
+ document_type: 文檔類型 ("paper", "cv", "general")
389
+
390
+ Returns:
391
+ 完整的 prompt 字符串
392
+ """
393
+ context = self.format_context(results, document_type=document_type)
394
+ return self.create_prompt(query, context, system_prompt, document_type)
395
+
src/retrievers/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 檢索器模組
3
+ """
4
+ from .base import BaseRetriever
5
+ from .bm25_retriever import BM25Retriever
6
+ from .vector_retriever import VectorRetriever
7
+ from .hybrid_search import HybridSearch
8
+ from .reranker import Reranker, RAGPipeline
9
+
10
+ __all__ = [
11
+ "BaseRetriever",
12
+ "BM25Retriever",
13
+ "VectorRetriever",
14
+ "HybridSearch",
15
+ "Reranker",
16
+ "RAGPipeline",
17
+ ]
src/retrievers/base.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 檢索器模組的抽象基類
3
+ """
4
+ from abc import ABC, abstractmethod
5
+ from typing import List, Dict, Optional
6
+
7
+
8
+ class BaseRetriever(ABC):
9
+ """檢索器的抽象基類"""
10
+
11
+ @abstractmethod
12
+ def retrieve(
13
+ self,
14
+ query: str,
15
+ top_k: int = 5,
16
+ metadata_filter: Optional[Dict] = None
17
+ ) -> List[Dict]:
18
+ """
19
+ 檢索相關文檔並返回帶有分數的結果。
20
+
21
+ Args:
22
+ query: 查詢文字
23
+ top_k: 返回前 k 個結果
24
+ metadata_filter: 可選的 metadata 過濾條件字典。
25
+ 例如: {"arxiv_id": "1234.5678"} 或 {"title": "Machine Learning"}
26
+ 支援多個條件,所有條件必須同時滿足(AND 邏輯)
27
+
28
+ Returns:
29
+ 相關文檔列表,每個文檔字典都應包含 "score" 鍵,
30
+ 且分數越高代表越相關。返回的結果會根據 metadata_filter 進行過濾。
31
+ """
32
+ pass
src/retrievers/bm25_retriever.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BM25 檢索器模組
3
+ """
4
+ from typing import List, Dict, Optional
5
+ from rank_bm25 import BM25Okapi
6
+ import re
7
+ from .base import BaseRetriever
8
+
9
+
10
+ class BM25Retriever(BaseRetriever):
11
+ """使用 BM25 演算法進行關鍵字檢索"""
12
+
13
+ def __init__(self, documents: List[Dict]):
14
+ """
15
+ 初始化 BM25 檢索器
16
+
17
+ Args:
18
+ documents: 文檔列表,每個文檔包含 "content" 和 "metadata"
19
+ """
20
+ self.documents = documents
21
+ self.texts = [doc["content"] for doc in documents]
22
+
23
+ # 對文字進行 tokenization(簡單的分詞)
24
+ tokenized_texts = [self._tokenize(text) for text in self.texts]
25
+
26
+ # 初始化 BM25
27
+ self.bm25 = BM25Okapi(tokenized_texts)
28
+
29
+ def _tokenize(self, text: str) -> List[str]:
30
+ """
31
+ 將文字轉換為 tokens(簡單的實作)
32
+
33
+ Args:
34
+ text: 輸入文字
35
+
36
+ Returns:
37
+ token 列表
38
+ """
39
+ # 轉為小寫並分割
40
+ text = text.lower()
41
+ # 使用正則表達式分割(保留字母和數字)
42
+ tokens = re.findall(r'\b\w+\b', text)
43
+ return tokens
44
+
45
+ def retrieve(
46
+ self,
47
+ query: str,
48
+ top_k: int = 5,
49
+ metadata_filter: Optional[Dict] = None
50
+ ) -> List[Dict]:
51
+ """
52
+ 檢索相關文檔,支援根據 metadata 進行過濾。
53
+
54
+ Args:
55
+ query: 查詢文字
56
+ top_k: 返回前 k 個結果
57
+ metadata_filter: 可選的 metadata 過濾條件字典。
58
+ 例如: {"arxiv_id": "1234.5678"} 只檢索特定論文的 chunks
59
+ 或 {"title": "Machine Learning"} 只檢索特定標題的論文
60
+ 支援多個條件,所有條件必須同時滿足(AND 邏輯)
61
+ 注意:BM25 的過濾是在檢索後進行的,所以可能會返回少於 top_k 的結果
62
+
63
+ Returns:
64
+ 相關文檔列表,每個包含 "content", "metadata", "score"
65
+ 結果會根據 metadata_filter 進行過濾
66
+ """
67
+ # Tokenize 查詢
68
+ tokenized_query = self._tokenize(query)
69
+
70
+ # 計算 BM25 分數
71
+ scores = self.bm25.get_scores(tokenized_query)
72
+
73
+ # 獲取所有結果並排序(先獲取更多結果以應對過濾後可能減少的情況)
74
+ # 如果沒有過濾條件,只需要 top_k 個;如果有過濾條件,需要更多候選結果
75
+ candidate_k = top_k * 3 if metadata_filter else top_k
76
+
77
+ # 獲取候選結果索引(按分數降序排列)
78
+ sorted_indices = sorted(
79
+ range(len(scores)),
80
+ key=lambda i: scores[i],
81
+ reverse=True
82
+ )[:candidate_k]
83
+
84
+ # 構建候選結果
85
+ candidate_results = []
86
+ for idx in sorted_indices:
87
+ candidate_results.append({
88
+ "content": self.documents[idx]["content"],
89
+ "metadata": self.documents[idx]["metadata"],
90
+ "score": float(scores[idx]),
91
+ })
92
+
93
+ # 如果提供了 metadata_filter,則進行過濾
94
+ if metadata_filter:
95
+ filtered_results = []
96
+ for result in candidate_results:
97
+ # 檢查該結果的 metadata 是否滿足所有過濾條件
98
+ metadata = result.get("metadata", {})
99
+ matches_all = True
100
+
101
+ for filter_key, filter_value in metadata_filter.items():
102
+ # 獲取文檔中對應的 metadata 值
103
+ doc_value = metadata.get(filter_key)
104
+
105
+ # 檢查是否匹配
106
+ # 支援精確匹配和部分匹配(如果 filter_value 是字串且 doc_value 也是字串)
107
+ if isinstance(filter_value, str) and isinstance(doc_value, str):
108
+ # 字串匹配:支援精確匹配或包含匹配
109
+ if filter_value.lower() not in doc_value.lower():
110
+ matches_all = False
111
+ break
112
+ else:
113
+ # 其他類型(數字、布林值等)使用精確匹配
114
+ if doc_value != filter_value:
115
+ matches_all = False
116
+ break
117
+
118
+ # 如果所有條件都滿足,則加入結果
119
+ if matches_all:
120
+ filtered_results.append(result)
121
+
122
+ # 返回過濾後的結果(最多 top_k 個)
123
+ return filtered_results[:top_k]
124
+ else:
125
+ # 沒有過濾條件,直接返回候選結果
126
+ return candidate_results
127
+
src/retrievers/hybrid_search.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid Search 模組:結合 BM25 和向量檢索
3
+ 支援兩種融合方法:加權求和(Weighted Sum)和倒數排名融合(RRF)
4
+ """
5
+ from typing import List, Dict, Optional, Literal
6
+ from .base import BaseRetriever
7
+ import numpy as np
8
+
9
+
10
+ class HybridSearch(BaseRetriever):
11
+ """結合稀疏和密集檢索的混合搜尋"""
12
+
13
+ def __init__(
14
+ self,
15
+ sparse_retriever: BaseRetriever,
16
+ dense_retriever: BaseRetriever,
17
+ sparse_weight: float = 0.4,
18
+ dense_weight: float = 0.6,
19
+ fusion_method: Literal["weighted_sum", "rrf"] = "rrf",
20
+ rrf_k: int = 60,
21
+ ):
22
+ """
23
+ 初始化 Hybrid Search
24
+
25
+ Args:
26
+ sparse_retriever: 稀疏檢索器 (例如 BM25)
27
+ dense_retriever: 密集檢索器 (例如向量檢索)
28
+ sparse_weight: 稀疏檢索分數的權重(僅用於 weighted_sum 方法)
29
+ dense_weight: 密集檢索分數的權重(僅用於 weighted_sum 方法)
30
+ fusion_method: 融合方法,可選 "weighted_sum" 或 "rrf"
31
+ - "weighted_sum": 加權求和,需要正規化分數並設置權重
32
+ - "rrf": 倒數排名融合(Reciprocal Rank Fusion),
33
+ 不需要分數正規化,對不同分數分佈更魯棒
34
+ rrf_k: RRF 方法中的常數 k,通常設為 60(僅用於 rrf 方法)
35
+ 較大的 k 值會讓排名較低的文檔獲得更多權重
36
+ """
37
+ self.sparse_retriever = sparse_retriever
38
+ self.dense_retriever = dense_retriever
39
+ self.fusion_method = fusion_method
40
+ self.rrf_k = rrf_k
41
+
42
+ # 僅在 weighted_sum 方法中使用權重
43
+ if fusion_method == "weighted_sum":
44
+ self.sparse_weight = sparse_weight
45
+ self.dense_weight = dense_weight
46
+
47
+ # 確保權重總和為 1
48
+ total_weight = sparse_weight + dense_weight
49
+ if abs(total_weight - 1.0) > 1e-6:
50
+ self.sparse_weight = sparse_weight / total_weight
51
+ self.dense_weight = dense_weight / total_weight
52
+
53
+ def _normalize_scores(self, results: List[Dict]) -> List[Dict]:
54
+ """
55
+ 將分數正規化到 [0, 1] 區間。
56
+ 僅用於 weighted_sum 方法。
57
+
58
+ Args:
59
+ results: 檢索結果列表,每個字典包含 'score'
60
+
61
+ Returns:
62
+ 帶有正規化分數的結果列表
63
+ """
64
+ scores = [res.get("score", 0.0) for res in results]
65
+ if not scores:
66
+ return results
67
+
68
+ scores_array = np.array(scores)
69
+ min_score = scores_array.min()
70
+ max_score = scores_array.max()
71
+
72
+ if max_score == min_score:
73
+ # 如果所有分數都相同,將它們設置為 1.0
74
+ normalized_scores = [1.0] * len(scores)
75
+ else:
76
+ normalized_scores = ((scores_array - min_score) / (max_score - min_score)).tolist()
77
+
78
+ for i, res in enumerate(results):
79
+ res["score"] = normalized_scores[i]
80
+
81
+ return results
82
+
83
+ def _get_doc_id(self, doc: Dict) -> str:
84
+ """
85
+ 從文檔中提取唯一標識符
86
+
87
+ Args:
88
+ doc: 文檔字典
89
+
90
+ Returns:
91
+ 文檔的唯一 ID
92
+ """
93
+ metadata = doc.get("metadata", {})
94
+ return f"{metadata.get('arxiv_id', 'unknown')}_{metadata.get('chunk_index', 0)}"
95
+
96
+ def _apply_rrf(
97
+ self,
98
+ sparse_results: List[Dict],
99
+ dense_results: List[Dict]
100
+ ) -> List[Dict]:
101
+ """
102
+ 應用倒數排名融合(Reciprocal Rank Fusion, RRF)方法
103
+
104
+ RRF 公式:RRF(d) = Σ(1 / (k + rank_i(d)))
105
+ 其中:
106
+ - d 是文檔
107
+ - rank_i(d) 是文檔在第 i 個檢索結果中的排名(從 1 開始)
108
+ - k 是常數(預設為 60)
109
+
110
+ RRF 的優點:
111
+ 1. 不需要分數正規化,對不同分數分佈的檢索器更魯棒
112
+ 2. 只依賴排名位置,不依賴分數值
113
+ 3. 自動處理分數分佈差異的問題
114
+
115
+ Args:
116
+ sparse_results: 稀疏檢索結果列表
117
+ dense_results: 密集檢索結果列表
118
+
119
+ Returns:
120
+ 融合後的結果列表,按 RRF 分數排序
121
+ """
122
+ # 建立文檔 ID 到 RRF 分數的映射
123
+ doc_to_rrf_score = {}
124
+
125
+ # 處理稀疏檢索結果(BM25)
126
+ for rank, result in enumerate(sparse_results, start=1):
127
+ doc_id = self._get_doc_id(result)
128
+ if doc_id not in doc_to_rrf_score:
129
+ doc_to_rrf_score[doc_id] = {
130
+ "doc": result,
131
+ "rrf_score": 0.0,
132
+ "sparse_rank": None,
133
+ "dense_rank": None
134
+ }
135
+ # 計算 RRF 貢獻:1 / (k + rank)
136
+ doc_to_rrf_score[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank)
137
+ doc_to_rrf_score[doc_id]["sparse_rank"] = rank
138
+
139
+ # 處理密集檢索結果(向量)
140
+ for rank, result in enumerate(dense_results, start=1):
141
+ doc_id = self._get_doc_id(result)
142
+ if doc_id not in doc_to_rrf_score:
143
+ doc_to_rrf_score[doc_id] = {
144
+ "doc": result,
145
+ "rrf_score": 0.0,
146
+ "sparse_rank": None,
147
+ "dense_rank": None
148
+ }
149
+ # 計算 RRF 貢獻:1 / (k + rank)
150
+ doc_to_rrf_score[doc_id]["rrf_score"] += 1.0 / (self.rrf_k + rank)
151
+ doc_to_rrf_score[doc_id]["dense_rank"] = rank
152
+
153
+ # 構建結果列表
154
+ rrf_results = []
155
+ for doc_id, data in doc_to_rrf_score.items():
156
+ result = data["doc"].copy()
157
+ result["hybrid_score"] = data["rrf_score"]
158
+ result["rrf_score"] = data["rrf_score"]
159
+ result["sparse_rank"] = data["sparse_rank"]
160
+ result["dense_rank"] = data["dense_rank"]
161
+
162
+ # 從原始結果中獲取分數以供參考
163
+ if data["sparse_rank"] is not None:
164
+ # 從稀疏檢索結果中獲取原始分數
165
+ for sparse_res in sparse_results:
166
+ if self._get_doc_id(sparse_res) == doc_id:
167
+ result["sparse_score"] = sparse_res.get("score", 0.0)
168
+ break
169
+ else:
170
+ result["sparse_score"] = None
171
+
172
+ if data["dense_rank"] is not None:
173
+ # 從密集檢索結果中獲取原始分數
174
+ for dense_res in dense_results:
175
+ if self._get_doc_id(dense_res) == doc_id:
176
+ result["dense_score"] = dense_res.get("score", 0.0)
177
+ break
178
+ else:
179
+ result["dense_score"] = None
180
+
181
+ rrf_results.append(result)
182
+
183
+ # 按 RRF 分數從高到低排序
184
+ rrf_results.sort(key=lambda x: x["rrf_score"], reverse=True)
185
+
186
+ return rrf_results
187
+
188
+ def _apply_weighted_sum(
189
+ self,
190
+ sparse_results: List[Dict],
191
+ dense_results: List[Dict]
192
+ ) -> List[Dict]:
193
+ """
194
+ 應用加權求和(Weighted Sum)方法
195
+
196
+ 此方法需要:
197
+ 1. 正規化兩組分數到相同範圍
198
+ 2. 根據權重進行加權求和
199
+
200
+ Args:
201
+ sparse_results: 稀疏檢索結果列表
202
+ dense_results: 密集檢索結果列表
203
+
204
+ Returns:
205
+ 融合後的結果列表,按混合分數排序
206
+ """
207
+ # 正規化兩組分數
208
+ normalized_sparse = self._normalize_scores(sparse_results)
209
+ normalized_dense = self._normalize_scores(dense_results)
210
+
211
+ # 結合分數
212
+ doc_to_scores = {}
213
+
214
+ # 處理稀疏檢索結果
215
+ for res in normalized_sparse:
216
+ doc_id = self._get_doc_id(res)
217
+ if doc_id not in doc_to_scores:
218
+ doc_to_scores[doc_id] = {"doc": res, "sparse": 0.0, "dense": 0.0}
219
+ doc_to_scores[doc_id]["sparse"] = res["score"]
220
+
221
+ # 處理密集檢索結果
222
+ for res in normalized_dense:
223
+ doc_id = self._get_doc_id(res)
224
+ if doc_id not in doc_to_scores:
225
+ doc_to_scores[doc_id] = {"doc": res, "sparse": 0.0, "dense": 0.0}
226
+ doc_to_scores[doc_id]["dense"] = res["score"]
227
+
228
+ # 計算混合分數並排序
229
+ hybrid_results = []
230
+ for doc_id, scores in doc_to_scores.items():
231
+ hybrid_score = (
232
+ self.sparse_weight * scores["sparse"] +
233
+ self.dense_weight * scores["dense"]
234
+ )
235
+
236
+ result = scores["doc"].copy()
237
+ result["hybrid_score"] = hybrid_score
238
+ result["sparse_score"] = scores["sparse"]
239
+ result["dense_score"] = scores["dense"]
240
+ hybrid_results.append(result)
241
+
242
+ # 按混合分數從高到低排序
243
+ hybrid_results.sort(key=lambda x: x["hybrid_score"], reverse=True)
244
+
245
+ return hybrid_results
246
+
247
+ def retrieve(
248
+ self,
249
+ query: str,
250
+ top_k: int = 5,
251
+ metadata_filter: Optional[Dict] = None
252
+ ) -> List[Dict]:
253
+ """
254
+ 執行混合搜尋,支援根據 metadata 進行過濾
255
+
256
+ Args:
257
+ query: 查詢文字
258
+ top_k: 返回前 k 個結果
259
+ metadata_filter: 可選的 metadata 過濾條件字典。
260
+ 例如: {"arxiv_id": "1234.5678"} 只檢索特定論文的 chunks
261
+ 或 {"title": "Machine Learning"} 只檢索特定標題的論文
262
+ 支援多個條件,所有條件必須同時滿足(AND 邏輯)
263
+ 此過濾條件會傳遞給底層的稀疏和密集檢索器
264
+
265
+ Returns:
266
+ 相關文檔列表,每個包含 "content", "metadata", "hybrid_score"
267
+ 結果會根據 metadata_filter 進行過濾
268
+
269
+ 根據 fusion_method 的不同,返回的結果會包含不同的分數欄位:
270
+ - RRF 方法:包含 "rrf_score", "sparse_rank", "dense_rank"
271
+ - Weighted Sum 方法:包含 "sparse_score", "dense_score"
272
+ """
273
+ # 1. 從兩個檢索器獲取結果(請求更多結果以確保覆蓋率)
274
+ # 將 metadata_filter 傳遞給底層檢索器
275
+ sparse_results = self.sparse_retriever.retrieve(
276
+ query,
277
+ top_k=top_k * 2,
278
+ metadata_filter=metadata_filter
279
+ )
280
+ dense_results = self.dense_retriever.retrieve(
281
+ query,
282
+ top_k=top_k * 2,
283
+ metadata_filter=metadata_filter
284
+ )
285
+
286
+ # 2. 根據選擇的融合方法進行結果融合
287
+ if self.fusion_method == "rrf":
288
+ # 使用 RRF(倒數排名融合)方法
289
+ # RRF 不需要分數正規化,直接基於排名進行融合
290
+ hybrid_results = self._apply_rrf(sparse_results, dense_results)
291
+ else:
292
+ # 使用加權求和方法
293
+ # 需要先正規化分數,然後根據權重進行加權求和
294
+ hybrid_results = self._apply_weighted_sum(sparse_results, dense_results)
295
+
296
+ # 3. 返回前 top_k 個結果
297
+ return hybrid_results[:top_k]
298
+
src/retrievers/reranker.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 重排序模組:使用 Cross-Encoder 進行精準重排
3
+ """
4
+ from typing import List, Dict, Optional, Tuple
5
+ from sentence_transformers import CrossEncoder
6
+ import time
7
+ import logging
8
+
9
+ # 嘗試導入 torch 來檢測可用的設備
10
+ try:
11
+ import torch
12
+ TORCH_AVAILABLE = True
13
+ except ImportError:
14
+ TORCH_AVAILABLE = False
15
+
16
+ # 配置日志
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def get_device() -> str:
22
+ """
23
+ 自動檢測並返回最佳可用的設備
24
+
25
+ Returns:
26
+ 設備名稱: 'mps' (macOS GPU), 'cuda' (NVIDIA GPU), 或 'cpu'
27
+ """
28
+ if not TORCH_AVAILABLE:
29
+ return 'cpu'
30
+
31
+ # 優先順序: MPS (macOS) > CUDA (NVIDIA) > CPU
32
+ if torch.backends.mps.is_available():
33
+ return 'mps'
34
+ elif torch.cuda.is_available():
35
+ return 'cuda'
36
+ else:
37
+ return 'cpu'
38
+
39
+
40
+ class Reranker:
41
+ """重排序組件:使用 Cross-Encoder 進行精準重排"""
42
+
43
+ def __init__(
44
+ self,
45
+ model_name: str = "BAAI/bge-reranker-base",
46
+ device: str = None,
47
+ max_length: int = 512,
48
+ batch_size: int = 32,
49
+ enable_cache: bool = True
50
+ ):
51
+ """
52
+ 初始化 Cross-Encoder 模型
53
+
54
+ Args:
55
+ model_name: Cross-Encoder 模型名稱
56
+ device: 設備名稱 ('cuda', 'cpu', 'mps')
57
+ max_length: 最大 token 長度(模型限制)
58
+ batch_size: 批處理大小,用於優化內存使用
59
+ enable_cache: 是否啟用模型緩存
60
+ """
61
+ try:
62
+ # 自動檢測設備(如果未指定)
63
+ if device is None:
64
+ device = get_device()
65
+
66
+ device_name_map = {
67
+ 'mps': 'MPS (macOS GPU)',
68
+ 'cuda': 'CUDA (NVIDIA GPU)',
69
+ 'cpu': 'CPU'
70
+ }
71
+ device_display = device_name_map.get(device, device)
72
+
73
+ self.model = CrossEncoder(
74
+ model_name,
75
+ device=device,
76
+ max_length=max_length
77
+ )
78
+ self.max_length = max_length
79
+ self.batch_size = batch_size
80
+ self.model_name = model_name
81
+ logger.info(f"✅ 重排模型 {model_name} 已載入 (device: {device_display})")
82
+ except Exception as e:
83
+ logger.error(f"❌ 模型載入失敗: {e}")
84
+ raise
85
+
86
+ def _truncate_text(self, text: str, max_chars: int = 2000) -> str:
87
+ """
88
+ 截斷過長的文本(粗略估計,避免超過 token 限制)
89
+
90
+ Args:
91
+ text: 原始文本
92
+ max_chars: 最大字符數(保守估計,約 500 tokens)
93
+
94
+ Returns:
95
+ 截斷後的文本
96
+ """
97
+ if len(text) <= max_chars:
98
+ return text
99
+ # 截斷並添加省略號
100
+ return text[:max_chars - 3] + "..."
101
+
102
+ def _prepare_pairs(
103
+ self,
104
+ query: str,
105
+ documents: List[Dict]
106
+ ) -> List[Tuple[str, str]]:
107
+ """
108
+ 準備 (query, document) 配對,處理文本長度
109
+
110
+ Args:
111
+ query: 查詢文本
112
+ documents: 文檔列表
113
+
114
+ Returns:
115
+ (query, content) 配對列表
116
+ """
117
+ pairs = []
118
+ truncated_indices = [] # 記錄哪些文檔被截斷了
119
+
120
+ # 粗略估計:每個字符約 0.25 tokens,為 query 預留空間
121
+ max_doc_chars = int((self.max_length * 0.7) - len(query))
122
+
123
+ for i, doc in enumerate(documents):
124
+ content = doc.get("content", "")
125
+ original_length = len(content)
126
+
127
+ # 如果內容過長,進行截斷
128
+ if len(content) > max_doc_chars:
129
+ content = self._truncate_text(content, max_doc_chars)
130
+ truncated_indices.append(i)
131
+
132
+ pairs.append([query, content])
133
+
134
+ if truncated_indices:
135
+ logger.warning(
136
+ f"⚠️ 有 {len(truncated_indices)} 個文檔因過長被截斷 "
137
+ f"(最大長度: {max_doc_chars} 字符)"
138
+ )
139
+
140
+ return pairs
141
+
142
+ def rerank(
143
+ self,
144
+ query: str,
145
+ documents: List[Dict],
146
+ top_k: int = 5,
147
+ preserve_original_scores: bool = True
148
+ ) -> List[Dict]:
149
+ """
150
+ 執行精準重排
151
+
152
+ Args:
153
+ query: 查詢文本
154
+ documents: 文檔列表,每個應包含 "content" 和可選的 "hybrid_score"
155
+ top_k: 返回前 k 個結果
156
+ preserve_original_scores: 是否保留原始分數(hybrid_score)
157
+
158
+ Returns:
159
+ 重排後的文檔列表,按 rerank_score 降序排列
160
+ """
161
+ if not documents:
162
+ logger.warning("⚠️ 文檔列表為空,返回空結果")
163
+ return []
164
+
165
+ if not query or not query.strip():
166
+ logger.warning("⚠️ 查詢為空,返回原始文檔順序")
167
+ return documents[:top_k]
168
+
169
+ start_time = time.time()
170
+ logger.info(f"🔄 開始重排 {len(documents)} 個文檔...")
171
+
172
+ try:
173
+ # 1. 準備配對
174
+ pairs = self._prepare_pairs(query, documents)
175
+
176
+ # 2. 批處理計算分數(優化內存使用)
177
+ scores = []
178
+ for i in range(0, len(pairs), self.batch_size):
179
+ batch_pairs = pairs[i:i + self.batch_size]
180
+ batch_scores = self.model.predict(batch_pairs)
181
+ scores.extend(batch_scores.tolist() if hasattr(batch_scores, 'tolist') else batch_scores)
182
+
183
+ # 3. 更新文檔分數
184
+ for i, doc in enumerate(documents):
185
+ doc = doc.copy() # 避免修改原始文檔
186
+ doc["rerank_score"] = float(scores[i])
187
+
188
+ # 保留原始分數供參考
189
+ if preserve_original_scores:
190
+ if "hybrid_score" not in doc:
191
+ # 如果沒有 hybrid_score,嘗試使用其他分數
192
+ doc["original_score"] = doc.get("score", 0.0)
193
+
194
+ documents[i] = doc
195
+
196
+ # 4. 根據 rerank_score 重新排序
197
+ reranked_docs = sorted(
198
+ documents,
199
+ key=lambda x: x.get("rerank_score", float('-inf')),
200
+ reverse=True
201
+ )
202
+
203
+ # 5. 統計資訊
204
+ elapsed_time = time.time() - start_time
205
+ avg_score = sum(scores) / len(scores) if scores else 0.0
206
+ max_score = max(scores) if scores else 0.0
207
+ min_score = min(scores) if scores else 0.0
208
+
209
+ logger.info(
210
+ f"✅ 重排完成 (耗時: {elapsed_time:.2f}s, "
211
+ f"平均分數: {avg_score:.4f}, "
212
+ f"範圍: [{min_score:.4f}, {max_score:.4f}])"
213
+ )
214
+
215
+ return reranked_docs[:top_k]
216
+
217
+ except Exception as e:
218
+ logger.error(f"❌ 重排過程出錯: {e}")
219
+ # 降級策略:返回原始順序的前 top_k 個
220
+ logger.warning("⚠️ 使用降級策略:返回原始順序")
221
+ return documents[:top_k]
222
+
223
+
224
+ class RAGPipeline:
225
+ """協調管線:管理完整的 RAG 流程(召回 + 重排)"""
226
+
227
+ def __init__(
228
+ self,
229
+ hybrid_search,
230
+ reranker,
231
+ recall_k: int = 25,
232
+ adaptive_recall: bool = True,
233
+ min_recall_k: int = 10,
234
+ max_recall_k: int = 50
235
+ ):
236
+ """
237
+ 初始化 RAG 管線
238
+
239
+ Args:
240
+ hybrid_search: HybridSearch 實例
241
+ reranker: Reranker 實例
242
+ recall_k: 第一階段召回的數量(預設值)
243
+ adaptive_recall: 是否根據查詢動態調整 recall_k
244
+ min_recall_k: 最小召回數量
245
+ max_recall_k: 最大召回數量
246
+ """
247
+ self.hybrid_search = hybrid_search
248
+ self.reranker = reranker
249
+ self.base_recall_k = recall_k
250
+ self.adaptive_recall = adaptive_recall
251
+ self.min_recall_k = min_recall_k
252
+ self.max_recall_k = max_recall_k
253
+
254
+ # 性能統計
255
+ self.stats = {
256
+ "total_queries": 0,
257
+ "avg_recall_time": 0.0,
258
+ "avg_rerank_time": 0.0,
259
+ "avg_total_time": 0.0
260
+ }
261
+
262
+ def _calculate_adaptive_recall_k(self, query: str) -> int:
263
+ """
264
+ 根據查詢複雜度動態計算 recall_k
265
+
266
+ Args:
267
+ query: 查詢文本
268
+
269
+ Returns:
270
+ 調整後的 recall_k
271
+ """
272
+ if not self.adaptive_recall:
273
+ return self.base_recall_k
274
+
275
+ # 簡單啟發式:根據查詢長度和關鍵詞數量調整
276
+ query_length = len(query.split())
277
+ keyword_count = len(set(query.lower().split()))
278
+
279
+ # 複雜查詢需要更多候選
280
+ if query_length > 10 or keyword_count > 5:
281
+ recall_k = min(self.base_recall_k * 2, self.max_recall_k)
282
+ elif query_length < 3:
283
+ recall_k = max(self.base_recall_k // 2, self.min_recall_k)
284
+ else:
285
+ recall_k = self.base_recall_k
286
+
287
+ return recall_k
288
+
289
+ def query(
290
+ self,
291
+ text: str,
292
+ top_k: int = 5,
293
+ metadata_filter: Optional[Dict] = None,
294
+ enable_rerank: bool = True,
295
+ return_stats: bool = False
296
+ ) -> List[Dict]:
297
+ """
298
+ 執行完整的搜尋流程
299
+
300
+ Args:
301
+ text: 查詢文本
302
+ top_k: 最終返回的結果數量
303
+ metadata_filter: 可選的 metadata 過濾條件
304
+ enable_rerank: 是否啟用重排序(可選,用於性能測試)
305
+ return_stats: 是否返回���能統計資訊
306
+
307
+ Returns:
308
+ 相關文檔列表,如果 return_stats=True,則返回 (results, stats) 元組
309
+ """
310
+ if not text or not text.strip():
311
+ logger.warning("⚠️ 查詢為空")
312
+ return []
313
+
314
+ total_start = time.time()
315
+ self.stats["total_queries"] += 1
316
+
317
+ # 動態計算 recall_k
318
+ recall_k = self._calculate_adaptive_recall_k(text)
319
+ logger.info(
320
+ f"🔍 搜尋中: '{text[:50]}...' "
321
+ f"(召回階段: {recall_k} 筆, 最終返回: {top_k} 筆)"
322
+ )
323
+
324
+ try:
325
+ # 第一階段:混合搜尋(召回階段)
326
+ recall_start = time.time()
327
+ initial_results = self.hybrid_search.retrieve(
328
+ query=text,
329
+ top_k=recall_k,
330
+ metadata_filter=metadata_filter
331
+ )
332
+ recall_time = time.time() - recall_start
333
+
334
+ if not initial_results:
335
+ logger.warning("⚠️ 召回階段未找到任何結果")
336
+ return []
337
+
338
+ logger.info(
339
+ f"✅ 召回階段完成: 找到 {len(initial_results)} 個候選 "
340
+ f"(耗時: {recall_time:.2f}s)"
341
+ )
342
+
343
+ # 第二階段:重排序(精篩階段)
344
+ if enable_rerank and len(initial_results) > top_k:
345
+ rerank_start = time.time()
346
+ final_results = self.reranker.rerank(
347
+ query=text,
348
+ documents=initial_results,
349
+ top_k=top_k
350
+ )
351
+ rerank_time = time.time() - rerank_start
352
+
353
+ logger.info(
354
+ f"✅ 重排階段完成: 從 {len(initial_results)} 個候選中選出 "
355
+ f"{len(final_results)} 個結果 (耗時: {rerank_time:.2f}s)"
356
+ )
357
+ else:
358
+ # 跳過重排序(用於性能測試或候選數較少時)
359
+ final_results = initial_results[:top_k]
360
+ rerank_time = 0.0
361
+ logger.info("⏭️ 跳過重排序階段(候選數不足或已禁用)")
362
+
363
+ # 更新統計資訊
364
+ total_time = time.time() - total_start
365
+ self._update_stats(recall_time, rerank_time, total_time)
366
+
367
+ # 添加性能資訊到結果(可選)
368
+ if return_stats:
369
+ stats = {
370
+ "recall_time": recall_time,
371
+ "rerank_time": rerank_time,
372
+ "total_time": total_time,
373
+ "recall_k": recall_k,
374
+ "candidates_found": len(initial_results),
375
+ "final_results": len(final_results)
376
+ }
377
+ return final_results, stats
378
+
379
+ return final_results
380
+
381
+ except Exception as e:
382
+ logger.error(f"❌ 查詢過程出錯: {e}")
383
+ # 降級策略:嘗試只使用召回階段
384
+ try:
385
+ logger.warning("⚠️ 嘗試降級策略:僅使用召回結果")
386
+ return self.hybrid_search.retrieve(text, top_k=top_k, metadata_filter=metadata_filter)
387
+ except Exception as e2:
388
+ logger.error(f"❌ 降級策略也失敗: {e2}")
389
+ return []
390
+
391
+ def _update_stats(self, recall_time: float, rerank_time: float, total_time: float):
392
+ """更新性能統計資訊"""
393
+ n = self.stats["total_queries"]
394
+ self.stats["avg_recall_time"] = (
395
+ (self.stats["avg_recall_time"] * (n - 1) + recall_time) / n
396
+ )
397
+ self.stats["avg_rerank_time"] = (
398
+ (self.stats["avg_rerank_time"] * (n - 1) + rerank_time) / n
399
+ )
400
+ self.stats["avg_total_time"] = (
401
+ (self.stats["avg_total_time"] * (n - 1) + total_time) / n
402
+ )
403
+
404
+ def get_stats(self) -> Dict:
405
+ """獲取性能統計資訊"""
406
+ return self.stats.copy()
407
+
408
+ def reset_stats(self):
409
+ """重置統計資訊"""
410
+ self.stats = {
411
+ "total_queries": 0,
412
+ "avg_recall_time": 0.0,
413
+ "avg_rerank_time": 0.0,
414
+ "avg_total_time": 0.0
415
+ }
416
+
417
+ def format_results_for_llm(
418
+ self,
419
+ results: List[Dict],
420
+ format_style: str = "detailed"
421
+ ) -> str:
422
+ """
423
+ 格式化檢索結果供 LLM 使用(需要導入 PromptFormatter)
424
+
425
+ Args:
426
+ results: 檢索結果列表
427
+ format_style: 格式風格 ("detailed", "simple", "minimal")
428
+
429
+ Returns:
430
+ 格式化後的上下文字符串
431
+ """
432
+ try:
433
+ from ..prompt_formatter import PromptFormatter
434
+ formatter = PromptFormatter(format_style=format_style)
435
+ return formatter.format_context(results)
436
+ except ImportError:
437
+ # 如果無法導入,使用簡單格式
438
+ formatted_parts = []
439
+ for i, result in enumerate(results, 1):
440
+ metadata = result.get("metadata", {})
441
+ content = result.get("content", "")
442
+ arxiv_id = metadata.get('arxiv_id', 'N/A')
443
+ title = metadata.get('title', 'N/A')
444
+ formatted_parts.append(
445
+ f"[來源 {i}: {title} (arXiv:{arxiv_id})]\n{content}\n"
446
+ )
447
+ return "\n" + "="*60 + "\n".join(formatted_parts)
448
+
src/retrievers/vector_retriever.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 向量檢索器模組:使用 embedding 和向量資料庫進行語義檢索
3
+
4
+ 支援兩種初始化方式:
5
+ 1. 自動初始化 embeddings(預設):根據參數創建新的 embedding 模型
6
+ 2. 使用外部 embeddings:接收已初始化的 embedding 模型(可與 DocumentProcessor 共用)
7
+ """
8
+ from typing import List, Dict, Optional, Any
9
+ from langchain_community.vectorstores import Chroma
10
+ from langchain_core.documents import Document
11
+ import os
12
+ from .base import BaseRetriever
13
+
14
+ # 嘗試導入 HuggingFaceEmbeddings(免費模型)
15
+ try:
16
+ from langchain_community.embeddings import HuggingFaceEmbeddings
17
+ except ImportError:
18
+ try:
19
+ from langchain_huggingface import HuggingFaceEmbeddings
20
+ except ImportError:
21
+ raise ImportError("需要安裝 langchain-community 或 langchain-huggingface 才能使用 Hugging Face embeddings")
22
+
23
+ # 導入 torch 來檢測可用的設備
24
+ try:
25
+ import torch
26
+ TORCH_AVAILABLE = True
27
+ except ImportError:
28
+ TORCH_AVAILABLE = False
29
+
30
+
31
+ def get_device() -> str:
32
+ """
33
+ 自動檢測並返回最佳可用的設備
34
+
35
+ Returns:
36
+ 設備名稱: 'mps' (macOS GPU), 'cuda' (NVIDIA GPU), 或 'cpu'
37
+ """
38
+ if not TORCH_AVAILABLE:
39
+ return 'cpu'
40
+
41
+ # 優先順序: MPS (macOS) > CUDA (NVIDIA) > CPU
42
+ if torch.backends.mps.is_available():
43
+ return 'mps'
44
+ elif torch.cuda.is_available():
45
+ return 'cuda'
46
+ else:
47
+ return 'cpu'
48
+
49
+
50
+ class VectorRetriever(BaseRetriever):
51
+ """使用向量檢索進行語義搜尋"""
52
+
53
+ def __init__(
54
+ self,
55
+ documents: List[Dict],
56
+ embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
57
+ persist_directory: Optional[str] = "./chroma_db",
58
+ hf_cache_dir: Optional[str] = None,
59
+ device: Optional[str] = None,
60
+ embeddings: Optional[Any] = None # 可選:外部傳入的 embedding 模型(優先使用)
61
+ ):
62
+ """
63
+ 初始化向量檢索器(使用 Hugging Face embeddings)
64
+
65
+ Args:
66
+ documents: 文檔列表,每個文檔包含 "content" 和 "metadata"
67
+ embedding_model: Hugging Face embedding 模型名稱(預設: "sentence-transformers/all-MiniLM-L6-v2")
68
+ 僅在 embeddings=None 時使用
69
+ persist_directory: Chroma 資料庫持久化目錄
70
+ hf_cache_dir: Hugging Face 模型緩存目錄(例如外接硬碟路徑)
71
+ 如果為 None,則使用環境變數 HF_HOME 或預設位置 ~/.cache/huggingface/
72
+ 僅在 embeddings=None 時使用
73
+ device: 設備名稱 ('mps', 'cuda', 'cpu'),如果為 None 則自動檢測最佳設備
74
+ 僅在 embeddings=None 時使用
75
+ embeddings: 可選的外部 embedding 模型物件
76
+ 如果提供,將優先使用此模型,忽略其他參數(embedding_model, hf_cache_dir, device)
77
+ 這允許與 DocumentProcessor 共用同一個 embedding 模型實例
78
+ 優點:
79
+ - 節省內存(只加載一次模型)
80
+ - 節省時間(避免重複初始化)
81
+ - 確保一致性(分塊和檢索使用相同的模型)
82
+ """
83
+ # 優先使用傳入的共用模型
84
+ if embeddings is not None:
85
+ self.embeddings = embeddings
86
+ print("✓ 使用外部傳入的 embeddings 模型(與 DocumentProcessor 共用)")
87
+ else:
88
+ # 若無傳入,則執行原有的初始化邏輯
89
+ print(f"使用 Hugging Face embedding 模型: {embedding_model}")
90
+
91
+ # 設置 Hugging Face 緩存目錄
92
+ if hf_cache_dir:
93
+ # 如果指定了緩存目錄,設置環境變數
94
+ os.environ['HF_HOME'] = hf_cache_dir
95
+ os.environ['TRANSFORMERS_CACHE'] = hf_cache_dir
96
+ print(f"模型將存儲在: {hf_cache_dir}")
97
+ else:
98
+ # 檢查是否已經設置了環境變數
99
+ default_cache = os.path.expanduser("~/.cache/huggingface")
100
+ current_cache = os.getenv('HF_HOME', default_cache)
101
+ print(f"模型緩存位置: {current_cache}")
102
+ print("提示: 可以通過設置 hf_cache_dir 參數或環境變數 HF_HOME 來指定外接硬碟路徑")
103
+
104
+ # 自動檢測或使用指定的設備
105
+ if device is None:
106
+ device = get_device()
107
+
108
+ device_name_map = {
109
+ 'mps': 'MPS (macOS GPU)',
110
+ 'cuda': 'CUDA (NVIDIA GPU)',
111
+ 'cpu': 'CPU'
112
+ }
113
+ print(f"使用設備: {device_name_map.get(device, device)}")
114
+ print("首次使用時會下載模型,請稍候...")
115
+
116
+ # 構建 model_kwargs,包含緩存目錄和設備
117
+ model_kwargs = {'device': device}
118
+ if hf_cache_dir:
119
+ model_kwargs['cache_dir'] = hf_cache_dir
120
+
121
+ self.embeddings = HuggingFaceEmbeddings(
122
+ model_name=embedding_model,
123
+ model_kwargs=model_kwargs,
124
+ encode_kwargs={'normalize_embeddings': True} # 正規化 embeddings 以提升效果
125
+ )
126
+
127
+ # 將文檔轉換為 LangChain Document 格式
128
+ # 需要將 metadata 中的列表轉換為字串,因為 ChromaDB 不接受列表類型
129
+ def sanitize_metadata(metadata: Dict) -> Dict:
130
+ """將 metadata 中的列表轉換為字串,以符合 ChromaDB 的要求"""
131
+ sanitized = {}
132
+ for key, value in metadata.items():
133
+ if isinstance(value, list):
134
+ # 將列表轉換為逗號分隔的字串
135
+ sanitized[key] = ", ".join(str(v) for v in value)
136
+ elif isinstance(value, (dict, set)):
137
+ # 將字典或集合轉換為字串
138
+ sanitized[key] = str(value)
139
+ else:
140
+ # 其他類型(str, int, float, bool, None)直接保留
141
+ sanitized[key] = value
142
+ return sanitized
143
+
144
+ langchain_docs = [
145
+ Document(
146
+ page_content=doc["content"],
147
+ metadata=sanitize_metadata(doc["metadata"])
148
+ )
149
+ for doc in documents
150
+ ]
151
+
152
+ # 創建向量資料庫
153
+ self.vectorstore = Chroma.from_documents(
154
+ documents=langchain_docs,
155
+ embedding=self.embeddings,
156
+ persist_directory=persist_directory
157
+ )
158
+
159
+ # 創建 retriever
160
+ self.retriever = self.vectorstore.as_retriever()
161
+
162
+ def retrieve(
163
+ self,
164
+ query: str,
165
+ top_k: int = 5,
166
+ metadata_filter: Optional[Dict] = None
167
+ ) -> List[Dict]:
168
+ """
169
+ 檢索相關文檔,並返回標準化的相似度分數(越高越好)。
170
+ 支援根據 metadata 進行過濾。
171
+
172
+ Args:
173
+ query: 查詢文字
174
+ top_k: 返回前 k 個結果
175
+ metadata_filter: 可選的 metadata 過濾條件字典。
176
+ 例如: {"arxiv_id": "1234.5678"} 只檢索特定論文的 chunks
177
+ 或 {"title": "Machine Learning"} 只檢索特定標題的論文
178
+ 支援多個條件,所有條件必須同時滿足(AND 邏輯)
179
+ 注意:ChromaDB 的 where 條件支援精確匹配,不支援部分匹配
180
+
181
+ Returns:
182
+ 相關文檔列表,每個包含 "content", "metadata", 和 "score"
183
+ 結果會根據 metadata_filter 進行過濾
184
+ """
185
+ # 構建過濾條件
186
+ # 如果提供了 metadata_filter,先獲取更多結果,然後在 Python 中進行過濾
187
+ # 這是因為 LangChain ChromaDB 的 similarity_search_with_score 方法
188
+ # 對 filter 參數的支援可能因版本而異
189
+ if metadata_filter:
190
+ # 獲取更多結果以確保有足夠的候選進行過濾
191
+ results_with_scores = self.vectorstore.similarity_search_with_score(
192
+ query,
193
+ k=top_k * 10 # 獲取更多結果
194
+ )
195
+
196
+ # 在 Python 中進行過濾
197
+ filtered_results = []
198
+ for doc, distance_score in results_with_scores:
199
+ metadata = doc.metadata
200
+ matches = True
201
+
202
+ for key, value in metadata_filter.items():
203
+ doc_value = metadata.get(key)
204
+
205
+ # 檢查是否匹配
206
+ if isinstance(value, dict):
207
+ # 支援運算符格式(例如 {"$eq": "value"})
208
+ if "$eq" in value:
209
+ if doc_value != value["$eq"]:
210
+ matches = False
211
+ break
212
+ else:
213
+ # 其他運算符可以在此擴展
214
+ matches = False
215
+ break
216
+ elif isinstance(value, str) and isinstance(doc_value, str):
217
+ # 字串匹配:支援部分匹配(包含)
218
+ if value.lower() not in doc_value.lower():
219
+ matches = False
220
+ break
221
+ else:
222
+ # 其他類型使用精確匹配
223
+ if doc_value != value:
224
+ matches = False
225
+ break
226
+
227
+ if matches:
228
+ filtered_results.append((doc, distance_score))
229
+
230
+ # 只保留前 top_k 個結果
231
+ results_with_scores = filtered_results[:top_k]
232
+ else:
233
+ # 沒有過濾條件,直接獲取結果
234
+ results_with_scores = self.vectorstore.similarity_search_with_score(
235
+ query,
236
+ k=top_k
237
+ )
238
+
239
+ # 構建結果並轉換分數
240
+ results = []
241
+ for doc, distance_score in results_with_scores:
242
+ # 因為 embedding 已正規化,L2 距離的平方為 2 - 2 * cos_sim
243
+ # -> cos_sim = 1 - (distance^2 / 2)
244
+ # 分數範圍在 [0, 1] 之間,越高越相似
245
+ similarity_score = 1 - (distance_score**2 / 2)
246
+
247
+ results.append({
248
+ "content": doc.page_content,
249
+ "metadata": doc.metadata,
250
+ "score": float(similarity_score),
251
+ })
252
+
253
+ return results
254
+
src/step_back_rag.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Step-back Prompting 雙軌 RAG:結合具體事實與抽象原理
3
+ 使用 Step-back Prompting 技術,同時檢索具體事實和抽象原理,提升回答質量
4
+ """
5
+ from typing import List, Dict, Optional
6
+ from .retrievers.reranker import RAGPipeline
7
+ from .retrievers.vector_retriever import VectorRetriever
8
+ from .prompt_formatter import PromptFormatter
9
+ from .llm_integration import OllamaLLM
10
+ import time
11
+ import logging
12
+ import hashlib
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class StepBackRAG:
19
+ """使用 Step-back Prompting 的雙軌 RAG 系統"""
20
+
21
+ def __init__(
22
+ self,
23
+ rag_pipeline: RAGPipeline,
24
+ vector_retriever: VectorRetriever,
25
+ llm: OllamaLLM,
26
+ step_back_temperature: float = 0.3, # 生成抽象問題時使用較低溫度
27
+ answer_temperature: float = 0.7,
28
+ enable_parallel: bool = True
29
+ ):
30
+ """
31
+ 初始化 Step-back RAG
32
+
33
+ Args:
34
+ rag_pipeline: RAG 管線實例(用於最終答案生成)
35
+ vector_retriever: 向量檢索器
36
+ llm: LLM 實例
37
+ step_back_temperature: 生成抽象問題的溫度(較低,更穩定)
38
+ answer_temperature: 生成答案的溫度
39
+ enable_parallel: 是否並行執行雙軌檢索
40
+ """
41
+ self.rag_pipeline = rag_pipeline
42
+ self.vector_retriever = vector_retriever
43
+ self.llm = llm
44
+ self.step_back_temperature = step_back_temperature
45
+ self.answer_temperature = answer_temperature
46
+ self.enable_parallel = enable_parallel
47
+
48
+ def _generate_step_back_question(self, question: str) -> str:
49
+ """
50
+ 生成 Step-back 抽象問題
51
+
52
+ Args:
53
+ question: 原始具體問題
54
+
55
+ Returns:
56
+ 抽象問題
57
+ """
58
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
59
+
60
+ if is_chinese:
61
+ prompt = f"""你是一個資深專家。請將以下具體問題轉換為一個更抽象、更基礎的原理性問題。
62
+ 這個抽象問題應該幫助理解該領域的基礎概念和原理,而不是直接回答具體問題。
63
+
64
+ 具體問題: {question}
65
+
66
+ 請生成一個抽象問題,用於檢索相關的原理和背景知識:
67
+ """
68
+ else:
69
+ prompt = f"""You are a senior expert. Please convert the following specific question into a more abstract, fundamental question about principles and concepts.
70
+ This abstract question should help understand the basic concepts and principles in this field, rather than directly answering the specific question.
71
+
72
+ Specific question: {question}
73
+
74
+ Please generate an abstract question for retrieving relevant principles and background knowledge:
75
+ """
76
+
77
+ try:
78
+ abstract_question = self.llm.generate(
79
+ prompt=prompt,
80
+ temperature=self.step_back_temperature,
81
+ max_tokens=200
82
+ )
83
+
84
+ abstract_question = abstract_question.strip()
85
+
86
+ if not abstract_question:
87
+ logger.warning("⚠️ 生成的抽象問題為空,使用原始問題")
88
+ return question
89
+
90
+ logger.info(f"✅ 生成抽象問題: '{abstract_question}'")
91
+ return abstract_question
92
+
93
+ except Exception as e:
94
+ logger.error(f"⚠️ 生成抽象問題時出錯: {e}")
95
+ return question
96
+
97
+ def _get_doc_id(self, doc: Dict) -> str:
98
+ """
99
+ 生成文檔的唯一標識符
100
+
101
+ Args:
102
+ doc: 文檔字典
103
+
104
+ Returns:
105
+ 唯一 ID
106
+ """
107
+ metadata = doc.get("metadata", {})
108
+ content = doc.get("content", "")
109
+
110
+ if "arxiv_id" in metadata and "chunk_index" in metadata:
111
+ return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
112
+ elif "file_path" in metadata and "chunk_index" in metadata:
113
+ return f"{metadata['file_path']}_{metadata['chunk_index']}"
114
+ else:
115
+ content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
116
+ return f"doc_{content_hash}"
117
+
118
+ def _retrieve_direct(self, question: str, top_k: int, metadata_filter: Optional[Dict] = None) -> List[Dict]:
119
+ """直接檢索原始問題(具體事實)"""
120
+ return self.vector_retriever.retrieve(
121
+ query=question,
122
+ top_k=top_k,
123
+ metadata_filter=metadata_filter
124
+ )
125
+
126
+ def _retrieve_step_back(self, question: str, top_k: int, metadata_filter: Optional[Dict] = None) -> tuple:
127
+ """Step-back 檢索(抽象原理)"""
128
+ abstract_question = self._generate_step_back_question(question)
129
+ results = self.vector_retriever.retrieve(
130
+ query=abstract_question,
131
+ top_k=top_k,
132
+ metadata_filter=metadata_filter
133
+ )
134
+ return results, abstract_question
135
+
136
+ def query(
137
+ self,
138
+ question: str,
139
+ top_k: int = 5,
140
+ metadata_filter: Optional[Dict] = None,
141
+ return_abstract_question: bool = False
142
+ ) -> Dict:
143
+ """
144
+ 執行雙軌檢索(不生成答案)
145
+
146
+ Args:
147
+ question: 原始問題
148
+ top_k: 每軌返回的結果數量
149
+ metadata_filter: 可選的 metadata 過濾條件
150
+ return_abstract_question: 是否返回抽象問題
151
+
152
+ Returns:
153
+ 包含雙軌檢索結果的字典
154
+ """
155
+ start_time = time.time()
156
+
157
+ if self.enable_parallel:
158
+ # 並行執行雙軌檢索
159
+ logger.info(f"🔄 並行執行雙軌檢索: '{question}'")
160
+ with ThreadPoolExecutor(max_workers=2) as executor:
161
+ direct_future = executor.submit(
162
+ self._retrieve_direct, question, top_k, metadata_filter
163
+ )
164
+ step_back_future = executor.submit(
165
+ self._retrieve_step_back, question, top_k, metadata_filter
166
+ )
167
+
168
+ specific_results = direct_future.result()
169
+ abstract_results, abstract_question = step_back_future.result()
170
+ else:
171
+ # 串行執行
172
+ logger.info(f"🔄 串行執行雙軌檢索: '{question}'")
173
+ specific_results = self._retrieve_direct(question, top_k, metadata_filter)
174
+ abstract_results, abstract_question = self._retrieve_step_back(question, top_k, metadata_filter)
175
+
176
+ elapsed_time = time.time() - start_time
177
+ logger.info(
178
+ f"✅ 雙軌檢索完成(耗時: {elapsed_time:.2f}s)\n"
179
+ f" 具體事實: {len(specific_results)} 個結果\n"
180
+ f" 抽象原理: {len(abstract_results)} 個結果"
181
+ )
182
+
183
+ return {
184
+ "specific_context": specific_results,
185
+ "abstract_context": abstract_results,
186
+ "abstract_question": abstract_question if return_abstract_question else None,
187
+ "question": question,
188
+ "elapsed_time": elapsed_time
189
+ }
190
+
191
+ def generate_answer(
192
+ self,
193
+ question: str,
194
+ formatter: PromptFormatter,
195
+ top_k: int = 5,
196
+ metadata_filter: Optional[Dict] = None,
197
+ document_type: str = "general",
198
+ return_abstract_question: bool = False
199
+ ) -> Dict:
200
+ """
201
+ 完整的 Step-back RAG 流程:雙軌檢索 -> 生成答案
202
+
203
+ Args:
204
+ question: 原始問題
205
+ formatter: Prompt 格式化器
206
+ top_k: 每軌用於生成答案的文檔數量
207
+ metadata_filter: 可選的 metadata 過濾條件
208
+ document_type: 文檔類型 ("paper", "cv", "general")
209
+ return_abstract_question: 是否返回抽象問題
210
+
211
+ Returns:
212
+ 包含檢索結果、生成的答案和統計資訊的字典
213
+ """
214
+ start_time = time.time()
215
+
216
+ # 第一步:雙軌檢索
217
+ retrieval_result = self.query(
218
+ question=question,
219
+ top_k=top_k,
220
+ metadata_filter=metadata_filter,
221
+ return_abstract_question=return_abstract_question
222
+ )
223
+
224
+ specific_results = retrieval_result["specific_context"]
225
+ abstract_results = retrieval_result["abstract_context"]
226
+
227
+ if not specific_results and not abstract_results:
228
+ return {
229
+ **retrieval_result,
230
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
231
+ "formatted_context": None,
232
+ "answer_time": 0.0,
233
+ "total_time": retrieval_result["elapsed_time"]
234
+ }
235
+
236
+ # 第二步:格式化雙軌上下文
237
+ specific_context = formatter.format_context(
238
+ specific_results,
239
+ document_type=document_type
240
+ ) if specific_results else "未找到相關的具體事實資料。"
241
+
242
+ abstract_context = formatter.format_context(
243
+ abstract_results,
244
+ document_type=document_type
245
+ ) if abstract_results else "未找到相關的基礎原理資料。"
246
+
247
+ # 第三步:創建融合提示詞(關鍵步驟)
248
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
249
+
250
+ if is_chinese:
251
+ final_prompt = f"""你是一個資深專家。請結合以下兩類資訊來回答使用者的具體問題。
252
+
253
+ 【基礎原理與背景】
254
+ {abstract_context}
255
+
256
+ 【具體事實資料】
257
+ {specific_context}
258
+
259
+ 使用者問題:{question}
260
+
261
+ 請根據原理推導並結合事實,給出一個專業且具備邏輯的回答:
262
+ """
263
+ else:
264
+ final_prompt = f"""You are a senior expert. Please answer the user's specific question by combining the following two types of information.
265
+
266
+ 【Fundamental Principles and Background】
267
+ {abstract_context}
268
+
269
+ 【Specific Facts and Data】
270
+ {specific_context}
271
+
272
+ User question: {question}
273
+
274
+ Please provide a professional and logical answer based on principles and facts:
275
+ """
276
+
277
+ # 第四步:生成回答
278
+ logger.info("🤖 生成回答中...")
279
+ answer_start = time.time()
280
+ try:
281
+ answer = self.llm.generate(
282
+ prompt=final_prompt,
283
+ temperature=self.answer_temperature,
284
+ max_tokens=2048
285
+ )
286
+ answer_time = time.time() - answer_start
287
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
288
+ except Exception as e:
289
+ logger.error(f"❌ 生成回答時出錯: {e}")
290
+ answer = f"生成回答時出錯: {e}"
291
+ answer_time = time.time() - answer_start
292
+
293
+ total_time = time.time() - start_time
294
+
295
+ return {
296
+ **retrieval_result,
297
+ "answer": answer,
298
+ "formatted_context": {
299
+ "specific": specific_context,
300
+ "abstract": abstract_context
301
+ },
302
+ "answer_time": answer_time,
303
+ "total_time": total_time
304
+ }
305
+
src/subquery_rag.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sub-query Decomposition RAG:將複雜問題拆解成子問題後檢索
3
+ """
4
+ from typing import List, Dict, Optional
5
+ from .retrievers.reranker import RAGPipeline
6
+ from .prompt_formatter import PromptFormatter
7
+ from .llm_integration import OllamaLLM
8
+ import hashlib
9
+ import time
10
+ import logging
11
+ from concurrent.futures import ThreadPoolExecutor, as_completed
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class SubQueryDecompositionRAG:
17
+ """使用子問題拆解的 RAG 系統"""
18
+
19
+ def __init__(
20
+ self,
21
+ rag_pipeline: RAGPipeline,
22
+ llm: OllamaLLM,
23
+ max_sub_queries: int = 3,
24
+ top_k_per_subquery: int = 5,
25
+ enable_parallel: bool = True
26
+ ):
27
+ """
28
+ 初始化 Sub-query Decomposition RAG
29
+
30
+ Args:
31
+ rag_pipeline: 現有的 RAG 管線實例
32
+ llm: LLM 實例(用於生成子問題)
33
+ max_sub_queries: 最多生成的子問題數量
34
+ top_k_per_subquery: 每個子問題檢索的結果數量
35
+ enable_parallel: 是否並行處理子查詢
36
+ """
37
+ self.rag_pipeline = rag_pipeline
38
+ self.llm = llm
39
+ self.max_sub_queries = max_sub_queries
40
+ self.top_k_per_subquery = top_k_per_subquery
41
+ self.enable_parallel = enable_parallel
42
+
43
+ def _generate_sub_queries(self, question: str) -> List[str]:
44
+ """
45
+ 將原始問題拆解成子問題
46
+
47
+ Args:
48
+ question: 原始問題
49
+
50
+ Returns:
51
+ 子問題列表
52
+ """
53
+ # 檢測語言
54
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
55
+
56
+ if is_chinese:
57
+ prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
58
+ 每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。
59
+
60
+ 原始問題: {question}
61
+
62
+ 子問題清單:"""
63
+ else:
64
+ prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
65
+ Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.
66
+
67
+ Original question: {question}
68
+
69
+ Sub-question list:"""
70
+
71
+ try:
72
+ response = self.llm.generate(
73
+ prompt=prompt,
74
+ temperature=0.3, # 降低溫度以獲得更穩定的結果
75
+ max_tokens=500
76
+ )
77
+
78
+ # 解析子問題
79
+ sub_queries = [
80
+ q.strip()
81
+ for q in response.strip().split("\n")
82
+ if q.strip() and not q.strip().startswith("#")
83
+ ]
84
+
85
+ # 移除編號前綴(如 "1. ", "1) " 等)
86
+ cleaned_queries = []
87
+ for q in sub_queries:
88
+ # 移除開頭的編號
89
+ q = q.lstrip("0123456789. )")
90
+ q = q.strip()
91
+ if q:
92
+ cleaned_queries.append(q)
93
+
94
+ # 限制數量
95
+ cleaned_queries = cleaned_queries[:self.max_sub_queries]
96
+
97
+ # 如果沒有生成子問題,使用原始問題
98
+ if not cleaned_queries:
99
+ logger.warning("⚠️ 未生成子問題,使用原始問題")
100
+ cleaned_queries = [question]
101
+
102
+ return cleaned_queries
103
+
104
+ except Exception as e:
105
+ logger.error(f"⚠️ 生成子問題時出錯: {e}")
106
+ # 回退到原始問題
107
+ return [question]
108
+
109
+ def _get_doc_id(self, doc: Dict) -> str:
110
+ """
111
+ 生成文檔的唯一標識符
112
+
113
+ Args:
114
+ doc: 文檔字典
115
+
116
+ Returns:
117
+ 唯一 ID
118
+ """
119
+ metadata = doc.get("metadata", {})
120
+ content = doc.get("content", "")
121
+
122
+ # 使用 metadata 中的唯一標識(如果有的話)
123
+ if "arxiv_id" in metadata and "chunk_index" in metadata:
124
+ return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
125
+ elif "file_path" in metadata and "chunk_index" in metadata:
126
+ return f"{metadata['file_path']}_{metadata['chunk_index']}"
127
+ else:
128
+ # 回退到內容的 hash
129
+ content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
130
+ return f"doc_{content_hash}"
131
+
132
+ def _retrieve_for_subquery(
133
+ self,
134
+ sub_query: str,
135
+ metadata_filter: Optional[Dict] = None
136
+ ) -> List[Dict]:
137
+ """
138
+ 針對單個子問題進行檢索
139
+
140
+ Args:
141
+ sub_query: 子問題
142
+ metadata_filter: 可選的 metadata 過濾條件
143
+
144
+ Returns:
145
+ 檢索結果列表
146
+ """
147
+ try:
148
+ results = self.rag_pipeline.query(
149
+ text=sub_query,
150
+ top_k=self.top_k_per_subquery,
151
+ metadata_filter=metadata_filter,
152
+ enable_rerank=True
153
+ )
154
+ return results
155
+ except Exception as e:
156
+ logger.error(f"⚠️ 檢索子問題 '{sub_query}' 時出錯: {e}")
157
+ return []
158
+
159
+ def _get_unique_documents(
160
+ self,
161
+ sub_queries: List[str],
162
+ metadata_filter: Optional[Dict] = None
163
+ ) -> List[Dict]:
164
+ """
165
+ 針對所有子問題進行檢索,並移除重複的檔案
166
+
167
+ Args:
168
+ sub_queries: 子問題列表
169
+ metadata_filter: 可選的 metadata 過濾條件
170
+
171
+ Returns:
172
+ 去重後的文檔列表
173
+ """
174
+ unique_docs = {}
175
+
176
+ if self.enable_parallel and len(sub_queries) > 1:
177
+ # 並行處理子查詢
178
+ logger.info(f"🔄 並行處理 {len(sub_queries)} 個子查詢...")
179
+ with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
180
+ future_to_query = {
181
+ executor.submit(self._retrieve_for_subquery, q, metadata_filter): q
182
+ for q in sub_queries
183
+ }
184
+
185
+ for future in as_completed(future_to_query):
186
+ sub_query = future_to_query[future]
187
+ try:
188
+ docs = future.result()
189
+ logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(docs)} 個結果")
190
+ for doc in docs:
191
+ doc_id = self._get_doc_id(doc)
192
+ if doc_id not in unique_docs:
193
+ unique_docs[doc_id] = doc
194
+ else:
195
+ # 如果已存在,保留分數更高的
196
+ existing_score = unique_docs[doc_id].get(
197
+ 'rerank_score',
198
+ unique_docs[doc_id].get('hybrid_score', 0)
199
+ )
200
+ new_score = doc.get(
201
+ 'rerank_score',
202
+ doc.get('hybrid_score', 0)
203
+ )
204
+ if new_score > existing_score:
205
+ unique_docs[doc_id] = doc
206
+ except Exception as e:
207
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
208
+ else:
209
+ # 串行處理
210
+ logger.info(f"🔄 串行處理 {len(sub_queries)} 個子查詢...")
211
+ for sub_query in sub_queries:
212
+ docs = self._retrieve_for_subquery(sub_query, metadata_filter)
213
+ logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(docs)} 個結果")
214
+ for doc in docs:
215
+ doc_id = self._get_doc_id(doc)
216
+ if doc_id not in unique_docs:
217
+ unique_docs[doc_id] = doc
218
+ else:
219
+ # 保留分數更高的
220
+ existing_score = unique_docs[doc_id].get(
221
+ 'rerank_score',
222
+ unique_docs[doc_id].get('hybrid_score', 0)
223
+ )
224
+ new_score = doc.get(
225
+ 'rerank_score',
226
+ doc.get('hybrid_score', 0)
227
+ )
228
+ if new_score > existing_score:
229
+ unique_docs[doc_id] = doc
230
+
231
+ # 按分數排序
232
+ result_list = list(unique_docs.values())
233
+ result_list.sort(
234
+ key=lambda x: x.get('rerank_score', x.get('hybrid_score', 0)),
235
+ reverse=True
236
+ )
237
+
238
+ return result_list
239
+
240
+ def query(
241
+ self,
242
+ question: str,
243
+ top_k: int = 5,
244
+ metadata_filter: Optional[Dict] = None,
245
+ return_sub_queries: bool = False
246
+ ) -> Dict:
247
+ """
248
+ 執行 Sub-query Decomposition RAG 查詢
249
+
250
+ Args:
251
+ question: 原始問題
252
+ top_k: 返回前 k 個結果
253
+ metadata_filter: 可選的 metadata 過濾條件
254
+ return_sub_queries: 是否在結果中包含子問題列表
255
+
256
+ Returns:
257
+ 包含檢索結果和統計資訊的字典
258
+ """
259
+ start_time = time.time()
260
+
261
+ # 第一步:產生子問題
262
+ logger.info(f"🔍 拆解問題: '{question}'")
263
+ sub_queries = self._generate_sub_queries(question)
264
+ logger.info(f"✅ 生成 {len(sub_queries)} 個子問題:")
265
+ for i, sq in enumerate(sub_queries, 1):
266
+ logger.info(f" {i}. {sq}")
267
+
268
+ # 第二步:檢索並去重
269
+ logger.info(f"📚 檢索相關文檔...")
270
+ docs = self._get_unique_documents(sub_queries, metadata_filter)
271
+ logger.info(f"✅ 找到 {len(docs)} 個唯一文檔(去重後)")
272
+
273
+ # 第三步:返回前 top_k 個結果
274
+ final_results = docs[:top_k]
275
+
276
+ elapsed_time = time.time() - start_time
277
+
278
+ result = {
279
+ "results": final_results,
280
+ "total_docs_found": len(docs),
281
+ "sub_queries": sub_queries if return_sub_queries else None,
282
+ "elapsed_time": elapsed_time
283
+ }
284
+
285
+ return result
286
+
287
+ def generate_answer(
288
+ self,
289
+ question: str,
290
+ formatter: PromptFormatter,
291
+ top_k: int = 5,
292
+ metadata_filter: Optional[Dict] = None,
293
+ document_type: str = "general",
294
+ return_sub_queries: bool = False
295
+ ) -> Dict:
296
+ """
297
+ 完整的 Sub-query Decomposition RAG 流程:檢索 + 生成答案
298
+
299
+ Args:
300
+ question: 原始問題
301
+ formatter: Prompt 格式化器
302
+ top_k: 返回前 k 個結果用於生成答案
303
+ metadata_filter: 可選的 metadata 過濾條件
304
+ document_type: 文檔類型 ("paper", "cv", "general")
305
+ return_sub_queries: 是否在結果中包含子問題列表
306
+
307
+ Returns:
308
+ 包含檢索結果、生成的答案和統計資訊的字典
309
+ """
310
+ # 檢索
311
+ retrieval_result = self.query(
312
+ question=question,
313
+ top_k=top_k,
314
+ metadata_filter=metadata_filter,
315
+ return_sub_queries=return_sub_queries
316
+ )
317
+
318
+ if not retrieval_result["results"]:
319
+ return {
320
+ **retrieval_result,
321
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
322
+ "formatted_context": None
323
+ }
324
+
325
+ # 格式化上下文
326
+ formatted_context = formatter.format_context(
327
+ retrieval_result["results"],
328
+ document_type=document_type
329
+ )
330
+
331
+ # 創建 prompt
332
+ prompt = formatter.create_prompt(
333
+ question,
334
+ formatted_context,
335
+ document_type=document_type
336
+ )
337
+
338
+ # 生成回答
339
+ logger.info("🤖 生成回答中...")
340
+ answer_start = time.time()
341
+ try:
342
+ answer = self.llm.generate(
343
+ prompt=prompt,
344
+ temperature=0.7,
345
+ max_tokens=2048
346
+ )
347
+ answer_time = time.time() - answer_start
348
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
349
+ except Exception as e:
350
+ logger.error(f"❌ 生成回答時出錯: {e}")
351
+ answer = f"生成回答時出錯: {e}"
352
+ answer_time = time.time() - answer_start
353
+
354
+ return {
355
+ **retrieval_result,
356
+ "answer": answer,
357
+ "formatted_context": formatted_context,
358
+ "answer_time": answer_time,
359
+ "total_time": retrieval_result["elapsed_time"] + answer_time
360
+ }
361
+
src/triple_hybrid_rag.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Triple Hybrid RAG:融合 SubQuery + HyDE + Step-back Prompting
3
+ 結合三種技術的優勢,實現最強大的 RAG 系統
4
+ """
5
+ from typing import List, Dict, Optional
6
+ from .retrievers.reranker import RAGPipeline
7
+ from .retrievers.vector_retriever import VectorRetriever
8
+ from .prompt_formatter import PromptFormatter
9
+ from .llm_integration import OllamaLLM
10
+ import hashlib
11
+ import time
12
+ import logging
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class TripleHybridRAG:
19
+ """融合 SubQuery + HyDE + Step-back 的三重混合 RAG 系統"""
20
+
21
+ def __init__(
22
+ self,
23
+ rag_pipeline: RAGPipeline,
24
+ vector_retriever: VectorRetriever,
25
+ llm: OllamaLLM,
26
+ max_sub_queries: int = 3,
27
+ top_k_per_subquery: int = 5,
28
+ hypothetical_length: int = 200,
29
+ temperature_subquery: float = 0.3,
30
+ temperature_hyde: float = 0.7,
31
+ temperature_stepback: float = 0.3,
32
+ answer_temperature: float = 0.7,
33
+ enable_parallel: bool = True
34
+ ):
35
+ """
36
+ 初始化三重混合 RAG
37
+
38
+ Args:
39
+ rag_pipeline: RAG 管線實例
40
+ vector_retriever: 向量檢索器
41
+ llm: LLM 實例
42
+ max_sub_queries: 最多生成的子問題數量
43
+ top_k_per_subquery: 每個子問題檢索的結果數量
44
+ hypothetical_length: 假設性文檔目標長度(字符數)
45
+ temperature_subquery: 生成子問題的溫度(較低,更穩定)
46
+ temperature_hyde: 生成假設性文檔的溫度(較高,更多專業術語)
47
+ temperature_stepback: 生成抽象問題的溫度(較低,更穩定)
48
+ answer_temperature: 生成答案的溫度
49
+ enable_parallel: 是否並行處理
50
+ """
51
+ self.rag_pipeline = rag_pipeline
52
+ self.vector_retriever = vector_retriever
53
+ self.llm = llm
54
+ self.max_sub_queries = max_sub_queries
55
+ self.top_k_per_subquery = top_k_per_subquery
56
+ self.hypothetical_length = hypothetical_length
57
+ self.temperature_subquery = temperature_subquery
58
+ self.temperature_hyde = temperature_hyde
59
+ self.temperature_stepback = temperature_stepback
60
+ self.answer_temperature = answer_temperature
61
+ self.enable_parallel = enable_parallel
62
+
63
+ def _generate_sub_queries(self, question: str) -> List[str]:
64
+ """生成子問題(SubQuery)"""
65
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
66
+
67
+ if is_chinese:
68
+ prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
69
+ 每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。
70
+
71
+ 原始問題: {question}
72
+
73
+ 子問題清單:"""
74
+ else:
75
+ prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
76
+ Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.
77
+
78
+ Original question: {question}
79
+
80
+ Sub-question list:"""
81
+
82
+ try:
83
+ response = self.llm.generate(
84
+ prompt=prompt,
85
+ temperature=self.temperature_subquery,
86
+ max_tokens=500
87
+ )
88
+
89
+ sub_queries = [
90
+ q.strip()
91
+ for q in response.strip().split("\n")
92
+ if q.strip() and not q.strip().startswith("#")
93
+ ]
94
+
95
+ # 移除編號前綴
96
+ cleaned_queries = []
97
+ for q in sub_queries:
98
+ q = q.lstrip("0123456789. )")
99
+ q = q.strip()
100
+ if q:
101
+ cleaned_queries.append(q)
102
+
103
+ cleaned_queries = cleaned_queries[:self.max_sub_queries]
104
+
105
+ if not cleaned_queries:
106
+ logger.warning("⚠️ 未生成子問題,使用原始問題")
107
+ cleaned_queries = [question]
108
+
109
+ return cleaned_queries
110
+
111
+ except Exception as e:
112
+ logger.error(f"⚠️ 生成子問題時出錯: {e}")
113
+ return [question]
114
+
115
+ def _generate_hypothetical_document(self, sub_query: str) -> str:
116
+ """為子問題生成假設性文檔(HyDE)"""
117
+ is_chinese = PromptFormatter.detect_language(sub_query) == "zh"
118
+
119
+ if is_chinese:
120
+ prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
121
+ 這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
122
+ 請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。
123
+
124
+ 問題: {sub_query}
125
+
126
+ 專業��術內容:"""
127
+ else:
128
+ prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
129
+ This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
130
+ Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.
131
+
132
+ Question: {sub_query}
133
+
134
+ Professional technical content:"""
135
+
136
+ try:
137
+ hypothetical_doc = self.llm.generate(
138
+ prompt=prompt,
139
+ temperature=self.temperature_hyde,
140
+ max_tokens=500
141
+ )
142
+
143
+ hypothetical_doc = hypothetical_doc.strip()
144
+
145
+ if not hypothetical_doc:
146
+ logger.warning(f"⚠️ 子問題 '{sub_query}' 的假設性文檔為空,使用子問題本身")
147
+ return sub_query
148
+
149
+ return hypothetical_doc
150
+
151
+ except Exception as e:
152
+ logger.error(f"⚠️ 生成假設性文檔時出錯: {e}")
153
+ return sub_query
154
+
155
+ def _generate_step_back_question(self, question: str) -> str:
156
+ """生成 Step-back 抽象問題"""
157
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
158
+
159
+ if is_chinese:
160
+ prompt = f"""你是一個資深專家。請將以下具體問題轉換為一個更抽象、更基礎的原理性問題。
161
+ 這個抽象問題應該幫助理解該領域的基礎概念和原理,而不是直接回答具體問題。
162
+
163
+ 具體問題: {question}
164
+
165
+ 請生成一個抽象問題,用於檢索相關的原理和背景知識:
166
+ """
167
+ else:
168
+ prompt = f"""You are a senior expert. Please convert the following specific question into a more abstract, fundamental question about principles and concepts.
169
+ This abstract question should help understand the basic concepts and principles in this field, rather than directly answering the specific question.
170
+
171
+ Specific question: {question}
172
+
173
+ Please generate an abstract question for retrieving relevant principles and background knowledge:
174
+ """
175
+
176
+ try:
177
+ abstract_question = self.llm.generate(
178
+ prompt=prompt,
179
+ temperature=self.temperature_stepback,
180
+ max_tokens=200
181
+ )
182
+
183
+ abstract_question = abstract_question.strip()
184
+
185
+ if not abstract_question:
186
+ logger.warning("⚠️ 生成的抽象問題為空,使用原始問題")
187
+ return question
188
+
189
+ return abstract_question
190
+
191
+ except Exception as e:
192
+ logger.error(f"⚠️ 生成抽象問題時出錯: {e}")
193
+ return question
194
+
195
+ def _get_doc_id(self, doc: Dict) -> str:
196
+ """生成文檔的唯一標識符"""
197
+ metadata = doc.get("metadata", {})
198
+ content = doc.get("content", "")
199
+
200
+ if "arxiv_id" in metadata and "chunk_index" in metadata:
201
+ return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
202
+ elif "file_path" in metadata and "chunk_index" in metadata:
203
+ return f"{metadata['file_path']}_{metadata['chunk_index']}"
204
+ else:
205
+ content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
206
+ return f"doc_{content_hash}"
207
+
208
+ def _process_subquery_with_hyde(
209
+ self,
210
+ sub_query: str,
211
+ metadata_filter: Optional[Dict] = None
212
+ ) -> tuple:
213
+ """處理單個子問題:生成假設性文檔並檢索"""
214
+ try:
215
+ hypothetical_doc = self._generate_hypothetical_document(sub_query)
216
+ results = self.vector_retriever.retrieve(
217
+ query=hypothetical_doc,
218
+ top_k=self.top_k_per_subquery,
219
+ metadata_filter=metadata_filter
220
+ )
221
+ return results, hypothetical_doc
222
+ except Exception as e:
223
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
224
+ return [], ""
225
+
226
+ def query(
227
+ self,
228
+ question: str,
229
+ top_k: int = 5,
230
+ metadata_filter: Optional[Dict] = None,
231
+ return_sub_queries: bool = False,
232
+ return_hypothetical: bool = False,
233
+ return_abstract_question: bool = False
234
+ ) -> Dict:
235
+ """
236
+ 執行三重混合 RAG 檢索
237
+
238
+ 流程:
239
+ 1. 拆解成子問題(SubQuery)
240
+ 2. 對每個子問題生成假設性文檔並檢索(HyDE)
241
+ 3. 直接檢索原始問題(具體事實)
242
+ 4. 生成抽象問題並檢索(Step-back,抽象原理)
243
+ 5. 合併所有結果並去重
244
+ """
245
+ start_time = time.time()
246
+
247
+ # 第一步:生成子問題
248
+ logger.info(f"🔍 [SubQuery] 拆解問題: '{question}'")
249
+ sub_queries = self._generate_sub_queries(question)
250
+ logger.info(f"✅ 生成 {len(sub_queries)} 個子問題")
251
+
252
+ # 第二步:為每個子問題生成假設性文檔並檢索(HyDE)
253
+ logger.info(f"📚 [HyDE] 為每個子問題生成假設性文檔並檢索...")
254
+ subquery_results = []
255
+ hypothetical_docs = {}
256
+
257
+ if self.enable_parallel and len(sub_queries) > 1:
258
+ with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
259
+ future_to_query = {
260
+ executor.submit(self._process_subquery_with_hyde, sq, metadata_filter): sq
261
+ for sq in sub_queries
262
+ }
263
+
264
+ for future in as_completed(future_to_query):
265
+ sub_query = future_to_query[future]
266
+ try:
267
+ results, hypo_doc = future.result()
268
+ hypothetical_docs[sub_query] = hypo_doc
269
+ subquery_results.extend(results)
270
+ except Exception as e:
271
+ logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
272
+ else:
273
+ for sub_query in sub_queries:
274
+ results, hypo_doc = self._process_subquery_with_hyde(sub_query, metadata_filter)
275
+ hypothetical_docs[sub_query] = hypo_doc
276
+ subquery_results.extend(results)
277
+
278
+ # 第三步:Step-back 雙軌檢索
279
+ logger.info(f"🔍 [Step-back] 執行雙軌檢索...")
280
+
281
+ if self.enable_parallel:
282
+ with ThreadPoolExecutor(max_workers=2) as executor:
283
+ direct_future = executor.submit(
284
+ self.vector_retriever.retrieve,
285
+ question, top_k, metadata_filter
286
+ )
287
+ abstract_question = self._generate_step_back_question(question)
288
+ step_back_future = executor.submit(
289
+ self.vector_retriever.retrieve,
290
+ abstract_question, top_k, metadata_filter
291
+ )
292
+
293
+ specific_results = direct_future.result()
294
+ abstract_results = step_back_future.result()
295
+ else:
296
+ specific_results = self.vector_retriever.retrieve(
297
+ query=question,
298
+ top_k=top_k,
299
+ metadata_filter=metadata_filter
300
+ )
301
+ abstract_question = self._generate_step_back_question(question)
302
+ abstract_results = self.vector_retriever.retrieve(
303
+ query=abstract_question,
304
+ top_k=top_k,
305
+ metadata_filter=metadata_filter
306
+ )
307
+
308
+ # 第四步:合併所有結果並去重
309
+ logger.info(f"🔄 合併並去重所有檢索結果...")
310
+ all_results = subquery_results + specific_results + abstract_results
311
+ unique_docs = {}
312
+
313
+ for doc in all_results:
314
+ doc_id = self._get_doc_id(doc)
315
+ if doc_id not in unique_docs:
316
+ unique_docs[doc_id] = doc
317
+ else:
318
+ # 保留分數更高的
319
+ existing_score = unique_docs[doc_id].get('score', 0)
320
+ new_score = doc.get('score', 0)
321
+ if new_score > existing_score:
322
+ unique_docs[doc_id] = doc
323
+
324
+ # 排序並返回前 top_k
325
+ result_list = list(unique_docs.values())
326
+ result_list.sort(key=lambda x: x.get('score', 0), reverse=True)
327
+ final_results = result_list[:top_k]
328
+
329
+ elapsed_time = time.time() - start_time
330
+ logger.info(
331
+ f"✅ 三重混合檢索完成(耗時: {elapsed_time:.2f}s)\n"
332
+ f" 子問題檢索: {len(subquery_results)} 個結果\n"
333
+ f" 具體事實: {len(specific_results)} 個結果\n"
334
+ f" 抽象原理: {len(abstract_results)} 個結果\n"
335
+ f" 去重後總計: {len(result_list)} 個,返回前 {len(final_results)} 個"
336
+ )
337
+
338
+ return {
339
+ "results": final_results,
340
+ "total_docs_found": len(result_list),
341
+ "sub_queries": sub_queries if return_sub_queries else None,
342
+ "hypothetical_documents": hypothetical_docs if return_hypothetical else None,
343
+ "abstract_question": abstract_question if return_abstract_question else None,
344
+ "subquery_results": subquery_results,
345
+ "specific_context": specific_results,
346
+ "abstract_context": abstract_results,
347
+ "question": question,
348
+ "elapsed_time": elapsed_time
349
+ }
350
+
351
+ def generate_answer(
352
+ self,
353
+ question: str,
354
+ formatter: PromptFormatter,
355
+ top_k: int = 5,
356
+ metadata_filter: Optional[Dict] = None,
357
+ document_type: str = "general",
358
+ return_sub_queries: bool = False,
359
+ return_hypothetical: bool = False,
360
+ return_abstract_question: bool = False
361
+ ) -> Dict:
362
+ """
363
+ 完整的三重混合 RAG 流程:檢索 + 生成答案
364
+ """
365
+ start_time = time.time()
366
+
367
+ # 檢索
368
+ retrieval_result = self.query(
369
+ question=question,
370
+ top_k=top_k,
371
+ metadata_filter=metadata_filter,
372
+ return_sub_queries=return_sub_queries,
373
+ return_hypothetical=return_hypothetical,
374
+ return_abstract_question=return_abstract_question
375
+ )
376
+
377
+ if not retrieval_result["results"]:
378
+ return {
379
+ **retrieval_result,
380
+ "answer": "抱歉,未找到相關文檔來回答此問題。",
381
+ "formatted_context": None,
382
+ "answer_time": 0.0,
383
+ "total_time": retrieval_result["elapsed_time"]
384
+ }
385
+
386
+ # 格式化三類上下文
387
+ subquery_context = formatter.format_context(
388
+ retrieval_result["subquery_results"][:top_k],
389
+ document_type=document_type
390
+ ) if retrieval_result.get("subquery_results") else "未找到相關的子問題檢索結果。"
391
+
392
+ specific_context = formatter.format_context(
393
+ retrieval_result["specific_context"],
394
+ document_type=document_type
395
+ ) if retrieval_result.get("specific_context") else "未找到相關的具體事實資料。"
396
+
397
+ abstract_context = formatter.format_context(
398
+ retrieval_result["abstract_context"],
399
+ document_type=document_type
400
+ ) if retrieval_result.get("abstract_context") else "未找到相關的基礎原理資料。"
401
+
402
+ # 創建融合提示詞(關鍵步驟)
403
+ is_chinese = PromptFormatter.detect_language(question) == "zh"
404
+
405
+ if is_chinese:
406
+ final_prompt = f"""你是一個資深專家。請結合以下三類資訊來回答使用者的具體問題。
407
+
408
+ 【基礎原理與背景】(來自 Step-back 抽象問題檢索)
409
+ {abstract_context}
410
+
411
+ 【具體事實資料】(來自直接問題檢索)
412
+ {specific_context}
413
+
414
+ 【子問題相關資料】(來自 SubQuery + HyDE 檢索)
415
+ {subquery_context}
416
+
417
+ 使用者問題:{question}
418
+
419
+ 請根據原理推導、結合具體事實,並參考子問題的相關資料,給出一個專業、全面且具備邏輯的回答:
420
+ """
421
+ else:
422
+ final_prompt = f"""You are a senior expert. Please answer the user's specific question by combining the following three types of information.
423
+
424
+ 【Fundamental Principles and Background】(from Step-back abstract question retrieval)
425
+ {abstract_context}
426
+
427
+ 【Specific Facts and Data】(from direct question retrieval)
428
+ {specific_context}
429
+
430
+ 【Sub-question Related Information】(from SubQuery + HyDE retrieval)
431
+ {subquery_context}
432
+
433
+ User question: {question}
434
+
435
+ Please provide a professional, comprehensive, and logical answer based on principles, facts, and sub-question related information:
436
+ """
437
+
438
+ # 生成回答
439
+ logger.info("🤖 生成回答中...")
440
+ answer_start = time.time()
441
+ try:
442
+ answer = self.llm.generate(
443
+ prompt=final_prompt,
444
+ temperature=self.answer_temperature,
445
+ max_tokens=2048
446
+ )
447
+ answer_time = time.time() - answer_start
448
+ logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
449
+ except Exception as e:
450
+ logger.error(f"❌ 生成回答時出錯: {e}")
451
+ answer = f"生成回答時出錯: {e}"
452
+ answer_time = time.time() - answer_start
453
+
454
+ total_time = time.time() - start_time
455
+
456
+ return {
457
+ **retrieval_result,
458
+ "answer": answer,
459
+ "formatted_context": {
460
+ "subquery": subquery_context,
461
+ "specific": specific_context,
462
+ "abstract": abstract_context
463
+ },
464
+ "answer_time": answer_time,
465
+ "total_time": total_time
466
+ }
467
+
uv.lock CHANGED
@@ -179,6 +179,19 @@ wheels = [
179
  { url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" },
180
  ]
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  [[package]]
183
  name = "attrs"
184
  version = "25.4.0"
@@ -808,6 +821,9 @@ version = "0.1.0"
808
  source = { virtual = "." }
809
  dependencies = [
810
  { name = "accelerate" },
 
 
 
811
  { name = "einops" },
812
  { name = "fastapi" },
813
  { name = "google-api-python-client" },
@@ -820,9 +836,12 @@ dependencies = [
820
  { name = "langchain" },
821
  { name = "langchain-chroma" },
822
  { name = "langchain-community" },
 
823
  { name = "langchain-google-genai" },
824
  { name = "langchain-groq" },
 
825
  { name = "langchain-tavily" },
 
826
  { name = "langgraph" },
827
  { name = "langserve", extra = ["all"] },
828
  { name = "mcp" },
@@ -833,6 +852,7 @@ dependencies = [
833
  { name = "pillow" },
834
  { name = "pypdf" },
835
  { name = "python-dotenv" },
 
836
  { name = "sentence-transformers" },
837
  { name = "tavily-python" },
838
  { name = "torch" },
@@ -844,6 +864,9 @@ dependencies = [
844
  [package.metadata]
845
  requires-dist = [
846
  { name = "accelerate", specifier = ">=1.12.0" },
 
 
 
847
  { name = "einops", specifier = ">=0.8.1" },
848
  { name = "fastapi", specifier = ">=0.124.2" },
849
  { name = "google-api-python-client", specifier = ">=2.187.0" },
@@ -856,9 +879,12 @@ requires-dist = [
856
  { name = "langchain", specifier = ">=1.1.3" },
857
  { name = "langchain-chroma", specifier = ">=1.0.0" },
858
  { name = "langchain-community", specifier = ">=0.4.1" },
 
859
  { name = "langchain-google-genai", specifier = ">=4.0.0" },
860
  { name = "langchain-groq", specifier = ">=1.1.0" },
 
861
  { name = "langchain-tavily", specifier = ">=0.2.13" },
 
862
  { name = "langgraph", specifier = ">=1.0.4" },
863
  { name = "langserve", extras = ["all"], specifier = ">=0.3.3" },
864
  { name = "mcp", specifier = ">=1.24.0" },
@@ -869,6 +895,7 @@ requires-dist = [
869
  { name = "pillow", specifier = ">=12.0.0" },
870
  { name = "pypdf", specifier = ">=6.4.1" },
871
  { name = "python-dotenv", specifier = ">=1.2.1" },
 
872
  { name = "sentence-transformers", specifier = ">=5.2.0" },
873
  { name = "tavily-python", specifier = ">=0.7.14" },
874
  { name = "torch", specifier = ">=2.9.1" },
@@ -908,6 +935,7 @@ wheels = [
908
  ]
909
 
910
  [[package]]
 
911
  name = "dnspython"
912
  version = "2.8.0"
913
  source = { registry = "https://pypi.org/simple" }
@@ -932,6 +960,14 @@ source = { registry = "https://pypi.org/simple" }
932
  sdist = { url = "https://files.pythonhosted.org/packages/ae/b6/03bb70946330e88ffec97aefd3ea75ba575cb2e762061e0e62a213befee8/docutils-0.22.4.tar.gz", hash = "sha256:4db53b1fde9abecbb74d91230d32ab626d94f6badfc575d6db9194a49df29968", size = 2291750, upload-time = "2025-12-18T19:00:26.443Z" }
933
  wheels = [
934
  { url = "https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl", hash = "sha256:d0013f540772d1420576855455d050a2180186c91c15779301ac2ccb3eeb68de", size = 633196, upload-time = "2025-12-18T19:00:18.077Z" },
 
 
 
 
 
 
 
 
935
  ]
936
 
937
  [[package]]
@@ -990,6 +1026,7 @@ wheels = [
990
  ]
991
 
992
  [[package]]
 
993
  name = "fastmcp"
994
  version = "2.13.0"
995
  source = { registry = "https://pypi.org/simple" }
@@ -1012,6 +1049,17 @@ dependencies = [
1012
  sdist = { url = "https://files.pythonhosted.org/packages/bc/3b/c30af894db2c3ec439d0e4168ba7ce705474cabdd0a599033ad9a19ad977/fastmcp-2.13.0.tar.gz", hash = "sha256:57f7b7503363e1babc0d1a13af18252b80366a409e1de85f1256cce66a4bee35", size = 7767346, upload-time = "2025-10-25T12:54:10.957Z" }
1013
  wheels = [
1014
  { url = "https://files.pythonhosted.org/packages/c0/7f/09942135f506953fc61bb81b9e5eaf50a8eea923b83d9135bd959168ef2d/fastmcp-2.13.0-py3-none-any.whl", hash = "sha256:bdff1399d3b7ebb79286edfd43eb660182432514a5ab8e4cbfb45f1d841d2aa0", size = 367134, upload-time = "2025-10-25T12:54:09.284Z" },
 
 
 
 
 
 
 
 
 
 
 
1015
  ]
1016
 
1017
  [[package]]
@@ -2116,6 +2164,19 @@ wheels = [
2116
  { url = "https://files.pythonhosted.org/packages/83/bd/9df897cbc98290bf71140104ee5b9777cf5291afb80333aa7da5a497339b/langchain_core-1.2.5-py3-none-any.whl", hash = "sha256:3255944ef4e21b2551facb319bfc426057a40247c0a05de5bd6f2fc021fbfa34", size = 484851, upload-time = "2025-12-22T23:45:30.525Z" },
2117
  ]
2118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2119
  [[package]]
2120
  name = "langchain-google-genai"
2121
  version = "4.1.1"
@@ -2144,6 +2205,19 @@ wheels = [
2144
  { url = "https://files.pythonhosted.org/packages/af/4a/3d6227a16fe9f79968414b50e50869519378b20653805e2e8fab283908e6/langchain_groq-1.1.1-py3-none-any.whl", hash = "sha256:1c6d5146f60205dcde09d7e47bb5291c295d3f0c7bcd2417e4d3a73a04bd1050", size = 19039, upload-time = "2025-12-12T22:00:45.86Z" },
2145
  ]
2146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2147
  [[package]]
2148
  name = "langchain-tavily"
2149
  version = "0.2.16"
@@ -2989,6 +3063,19 @@ wheels = [
2989
  { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" },
2990
  ]
2991
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2992
  [[package]]
2993
  name = "onnxruntime"
2994
  version = "1.23.2"
@@ -4101,6 +4188,18 @@ wheels = [
4101
  { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
4102
  ]
4103
 
 
 
 
 
 
 
 
 
 
 
 
 
4104
  [[package]]
4105
  name = "referencing"
4106
  version = "0.36.2"
@@ -4585,6 +4684,12 @@ wheels = [
4585
  { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" },
4586
  ]
4587
 
 
 
 
 
 
 
4588
  [[package]]
4589
  name = "shellingham"
4590
  version = "1.5.4"
 
179
  { url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" },
180
  ]
181
 
182
+ [[package]]
183
+ name = "arxiv"
184
+ version = "2.3.1"
185
+ source = { registry = "https://pypi.org/simple" }
186
+ dependencies = [
187
+ { name = "feedparser" },
188
+ { name = "requests" },
189
+ ]
190
+ sdist = { url = "https://files.pythonhosted.org/packages/dd/95/65e38ddfb54762a8f1777bbe80da2cebf7941376e67a2212de487d9372db/arxiv-2.3.1.tar.gz", hash = "sha256:08567185dfc102c8d349de4b9e84dfde0af46d6402486e3009afc90f8ccf9709", size = 16692, upload-time = "2025-11-13T06:22:59.853Z" }
191
+ wheels = [
192
+ { url = "https://files.pythonhosted.org/packages/90/7f/340847023184305a6378d75ec71e1dd38a942dfe71b7c29314b8fbe26948/arxiv-2.3.1-py3-none-any.whl", hash = "sha256:eb5a0b76808cc0a16de0c1448df0f927a3cf576096686d8e335a98b8872df1be", size = 11565, upload-time = "2025-11-13T06:22:58.662Z" },
193
+ ]
194
+
195
  [[package]]
196
  name = "attrs"
197
  version = "25.4.0"
 
821
  source = { virtual = "." }
822
  dependencies = [
823
  { name = "accelerate" },
824
+ { name = "arxiv" },
825
+ { name = "chromadb" },
826
+ { name = "docx2txt" },
827
  { name = "einops" },
828
  { name = "fastapi" },
829
  { name = "google-api-python-client" },
 
836
  { name = "langchain" },
837
  { name = "langchain-chroma" },
838
  { name = "langchain-community" },
839
+ { name = "langchain-experimental" },
840
  { name = "langchain-google-genai" },
841
  { name = "langchain-groq" },
842
+ { name = "langchain-ollama" },
843
  { name = "langchain-tavily" },
844
+ { name = "langchain-text-splitters" },
845
  { name = "langgraph" },
846
  { name = "langserve", extra = ["all"] },
847
  { name = "mcp" },
 
852
  { name = "pillow" },
853
  { name = "pypdf" },
854
  { name = "python-dotenv" },
855
+ { name = "rank-bm25" },
856
  { name = "sentence-transformers" },
857
  { name = "tavily-python" },
858
  { name = "torch" },
 
864
  [package.metadata]
865
  requires-dist = [
866
  { name = "accelerate", specifier = ">=1.12.0" },
867
+ { name = "arxiv", specifier = ">=2.3.1" },
868
+ { name = "chromadb", specifier = ">=0.4.22" },
869
+ { name = "docx2txt", specifier = ">=0.8" },
870
  { name = "einops", specifier = ">=0.8.1" },
871
  { name = "fastapi", specifier = ">=0.124.2" },
872
  { name = "google-api-python-client", specifier = ">=2.187.0" },
 
879
  { name = "langchain", specifier = ">=1.1.3" },
880
  { name = "langchain-chroma", specifier = ">=1.0.0" },
881
  { name = "langchain-community", specifier = ">=0.4.1" },
882
+ { name = "langchain-experimental", specifier = ">=0.0.50" },
883
  { name = "langchain-google-genai", specifier = ">=4.0.0" },
884
  { name = "langchain-groq", specifier = ">=1.1.0" },
885
+ { name = "langchain-ollama", specifier = ">=0.1.0" },
886
  { name = "langchain-tavily", specifier = ">=0.2.13" },
887
+ { name = "langchain-text-splitters", specifier = ">=0.0.1" },
888
  { name = "langgraph", specifier = ">=1.0.4" },
889
  { name = "langserve", extras = ["all"], specifier = ">=0.3.3" },
890
  { name = "mcp", specifier = ">=1.24.0" },
 
895
  { name = "pillow", specifier = ">=12.0.0" },
896
  { name = "pypdf", specifier = ">=6.4.1" },
897
  { name = "python-dotenv", specifier = ">=1.2.1" },
898
+ { name = "rank-bm25", specifier = ">=0.2.2" },
899
  { name = "sentence-transformers", specifier = ">=5.2.0" },
900
  { name = "tavily-python", specifier = ">=0.7.14" },
901
  { name = "torch", specifier = ">=2.9.1" },
 
935
  ]
936
 
937
  [[package]]
938
+ <<<<<<< HEAD
939
  name = "dnspython"
940
  version = "2.8.0"
941
  source = { registry = "https://pypi.org/simple" }
 
960
  sdist = { url = "https://files.pythonhosted.org/packages/ae/b6/03bb70946330e88ffec97aefd3ea75ba575cb2e762061e0e62a213befee8/docutils-0.22.4.tar.gz", hash = "sha256:4db53b1fde9abecbb74d91230d32ab626d94f6badfc575d6db9194a49df29968", size = 2291750, upload-time = "2025-12-18T19:00:26.443Z" }
961
  wheels = [
962
  { url = "https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl", hash = "sha256:d0013f540772d1420576855455d050a2180186c91c15779301ac2ccb3eeb68de", size = 633196, upload-time = "2025-12-18T19:00:18.077Z" },
963
+ =======
964
+ name = "docx2txt"
965
+ version = "0.9"
966
+ source = { registry = "https://pypi.org/simple" }
967
+ sdist = { url = "https://files.pythonhosted.org/packages/ea/07/4486a038624e885e227fe79111914c01f55aa70a51920ff1a7f2bd216d10/docx2txt-0.9.tar.gz", hash = "sha256:18013f6229b14909028b19aa7bf4f8f3d6e4632d7b089ab29f7f0a4d1f660e28", size = 3613, upload-time = "2025-03-24T20:59:25.21Z" }
968
+ wheels = [
969
+ { url = "https://files.pythonhosted.org/packages/d6/51/756e71bec48ece0ecc2a10e921ef2756e197dcb7e478f2b43673b6683902/docx2txt-0.9-py3-none-any.whl", hash = "sha256:e3718c0653fd6f2fcf4b51b02a61452ad1c38a4c163bcf0a6fd9486cd38f529a", size = 4025, upload-time = "2025-03-24T20:59:24.394Z" },
970
+ >>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
971
  ]
972
 
973
  [[package]]
 
1026
  ]
1027
 
1028
  [[package]]
1029
+ <<<<<<< HEAD
1030
  name = "fastmcp"
1031
  version = "2.13.0"
1032
  source = { registry = "https://pypi.org/simple" }
 
1049
  sdist = { url = "https://files.pythonhosted.org/packages/bc/3b/c30af894db2c3ec439d0e4168ba7ce705474cabdd0a599033ad9a19ad977/fastmcp-2.13.0.tar.gz", hash = "sha256:57f7b7503363e1babc0d1a13af18252b80366a409e1de85f1256cce66a4bee35", size = 7767346, upload-time = "2025-10-25T12:54:10.957Z" }
1050
  wheels = [
1051
  { url = "https://files.pythonhosted.org/packages/c0/7f/09942135f506953fc61bb81b9e5eaf50a8eea923b83d9135bd959168ef2d/fastmcp-2.13.0-py3-none-any.whl", hash = "sha256:bdff1399d3b7ebb79286edfd43eb660182432514a5ab8e4cbfb45f1d841d2aa0", size = 367134, upload-time = "2025-10-25T12:54:09.284Z" },
1052
+ =======
1053
+ name = "feedparser"
1054
+ version = "6.0.12"
1055
+ source = { registry = "https://pypi.org/simple" }
1056
+ dependencies = [
1057
+ { name = "sgmllib3k" },
1058
+ ]
1059
+ sdist = { url = "https://files.pythonhosted.org/packages/dc/79/db7edb5e77d6dfbc54d7d9df72828be4318275b2e580549ff45a962f6461/feedparser-6.0.12.tar.gz", hash = "sha256:64f76ce90ae3e8ef5d1ede0f8d3b50ce26bcce71dd8ae5e82b1cd2d4a5f94228", size = 286579, upload-time = "2025-09-10T13:33:59.486Z" }
1060
+ wheels = [
1061
+ { url = "https://files.pythonhosted.org/packages/4e/eb/c96d64137e29ae17d83ad2552470bafe3a7a915e85434d9942077d7fd011/feedparser-6.0.12-py3-none-any.whl", hash = "sha256:6bbff10f5a52662c00a2e3f86a38928c37c48f77b3c511aedcd51de933549324", size = 81480, upload-time = "2025-09-10T13:33:58.022Z" },
1062
+ >>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
1063
  ]
1064
 
1065
  [[package]]
 
2164
  { url = "https://files.pythonhosted.org/packages/83/bd/9df897cbc98290bf71140104ee5b9777cf5291afb80333aa7da5a497339b/langchain_core-1.2.5-py3-none-any.whl", hash = "sha256:3255944ef4e21b2551facb319bfc426057a40247c0a05de5bd6f2fc021fbfa34", size = 484851, upload-time = "2025-12-22T23:45:30.525Z" },
2165
  ]
2166
 
2167
+ [[package]]
2168
+ name = "langchain-experimental"
2169
+ version = "0.4.1"
2170
+ source = { registry = "https://pypi.org/simple" }
2171
+ dependencies = [
2172
+ { name = "langchain-community" },
2173
+ { name = "langchain-core" },
2174
+ ]
2175
+ sdist = { url = "https://files.pythonhosted.org/packages/a2/ec/6fe7b2e3c105b4f4fc6b943d8fc1b5b10f883429edc36c58a09fc2e28419/langchain_experimental-0.4.1.tar.gz", hash = "sha256:ab6b19a0b98fbc15225fbfcf096176fec339b7e3e930bcf328bb717985fc1da5", size = 170449, upload-time = "2025-12-11T05:30:48.455Z" }
2176
+ wheels = [
2177
+ { url = "https://files.pythonhosted.org/packages/24/fa/fb2c8b6418e1c9ef50c82b3b6e0184bce321582577240bb4b8ed3274a4aa/langchain_experimental-0.4.1-py3-none-any.whl", hash = "sha256:b6ee2f42b50aaadb45e581439ecf5ee50f3a6a0986d52e74d1e64721309e387d", size = 210096, upload-time = "2025-12-11T05:30:47.234Z" },
2178
+ ]
2179
+
2180
  [[package]]
2181
  name = "langchain-google-genai"
2182
  version = "4.1.1"
 
2205
  { url = "https://files.pythonhosted.org/packages/af/4a/3d6227a16fe9f79968414b50e50869519378b20653805e2e8fab283908e6/langchain_groq-1.1.1-py3-none-any.whl", hash = "sha256:1c6d5146f60205dcde09d7e47bb5291c295d3f0c7bcd2417e4d3a73a04bd1050", size = 19039, upload-time = "2025-12-12T22:00:45.86Z" },
2206
  ]
2207
 
2208
+ [[package]]
2209
+ name = "langchain-ollama"
2210
+ version = "1.0.1"
2211
+ source = { registry = "https://pypi.org/simple" }
2212
+ dependencies = [
2213
+ { name = "langchain-core" },
2214
+ { name = "ollama" },
2215
+ ]
2216
+ sdist = { url = "https://files.pythonhosted.org/packages/73/51/72cd04d74278f3575f921084f34280e2f837211dc008c9671c268c578afe/langchain_ollama-1.0.1.tar.gz", hash = "sha256:e37880c2f41cdb0895e863b1cfd0c2c840a117868b3f32e44fef42569e367443", size = 153850, upload-time = "2025-12-12T21:48:28.68Z" }
2217
+ wheels = [
2218
+ { url = "https://files.pythonhosted.org/packages/e3/46/f2907da16dc5a5a6c679f83b7de21176178afad8d2ca635a581429580ef6/langchain_ollama-1.0.1-py3-none-any.whl", hash = "sha256:37eb939a4718a0255fe31e19fbb0def044746c717b01b97d397606ebc3e9b440", size = 29207, upload-time = "2025-12-12T21:48:27.832Z" },
2219
+ ]
2220
+
2221
  [[package]]
2222
  name = "langchain-tavily"
2223
  version = "0.2.16"
 
3063
  { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" },
3064
  ]
3065
 
3066
+ [[package]]
3067
+ name = "ollama"
3068
+ version = "0.6.1"
3069
+ source = { registry = "https://pypi.org/simple" }
3070
+ dependencies = [
3071
+ { name = "httpx" },
3072
+ { name = "pydantic" },
3073
+ ]
3074
+ sdist = { url = "https://files.pythonhosted.org/packages/9d/5a/652dac4b7affc2b37b95386f8ae78f22808af09d720689e3d7a86b6ed98e/ollama-0.6.1.tar.gz", hash = "sha256:478c67546836430034b415ed64fa890fd3d1ff91781a9d548b3325274e69d7c6", size = 51620, upload-time = "2025-11-13T23:02:17.416Z" }
3075
+ wheels = [
3076
+ { url = "https://files.pythonhosted.org/packages/47/4f/4a617ee93d8208d2bcf26b2d8b9402ceaed03e3853c754940e2290fed063/ollama-0.6.1-py3-none-any.whl", hash = "sha256:fc4c984b345735c5486faeee67d8a265214a31cbb828167782dc642ce0a2bf8c", size = 14354, upload-time = "2025-11-13T23:02:16.292Z" },
3077
+ ]
3078
+
3079
  [[package]]
3080
  name = "onnxruntime"
3081
  version = "1.23.2"
 
4188
  { url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
4189
  ]
4190
 
4191
+ [[package]]
4192
+ name = "rank-bm25"
4193
+ version = "0.2.2"
4194
+ source = { registry = "https://pypi.org/simple" }
4195
+ dependencies = [
4196
+ { name = "numpy" },
4197
+ ]
4198
+ sdist = { url = "https://files.pythonhosted.org/packages/fc/0a/f9579384aa017d8b4c15613f86954b92a95a93d641cc849182467cf0bb3b/rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d", size = 8347, upload-time = "2022-02-16T12:10:52.196Z" }
4199
+ wheels = [
4200
+ { url = "https://files.pythonhosted.org/packages/2a/21/f691fb2613100a62b3fa91e9988c991e9ca5b89ea31c0d3152a3210344f9/rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae", size = 8584, upload-time = "2022-02-16T12:10:50.626Z" },
4201
+ ]
4202
+
4203
  [[package]]
4204
  name = "referencing"
4205
  version = "0.36.2"
 
4684
  { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" },
4685
  ]
4686
 
4687
+ [[package]]
4688
+ name = "sgmllib3k"
4689
+ version = "1.0.0"
4690
+ source = { registry = "https://pypi.org/simple" }
4691
+ sdist = { url = "https://files.pythonhosted.org/packages/9e/bd/3704a8c3e0942d711c1299ebf7b9091930adae6675d7c8f476a7ce48653c/sgmllib3k-1.0.0.tar.gz", hash = "sha256:7868fb1c8bfa764c1ac563d3cf369c381d1325d36124933a726f29fcdaa812e9", size = 5750, upload-time = "2010-08-24T14:33:52.445Z" }
4692
+
4693
  [[package]]
4694
  name = "shellingham"
4695
  version = "1.5.4"