Spaces:
Sleeping
Sleeping
Add main source
Browse files- .gitignore +9 -0
- Deep_Agent_Gradio_RAG_localLLM_main.py +1 -1
- OLLAMA_SETUP.md +153 -0
- PRIVATE_FILE_RAG_GUIDE.md +129 -0
- README.md +117 -0
- deep_agent_rag/config.py +7 -0
- deep_agent_rag/rag/adaptive_rag_selector.py +292 -0
- deep_agent_rag/rag/llm_adapter.py +151 -0
- deep_agent_rag/rag/private_file_rag.py +0 -0
- deep_agent_rag/ui/calendar_interface.py +653 -0
- deep_agent_rag/ui/email_interface.py +259 -0
- deep_agent_rag/ui/gradio_interface.py +13 -891
- deep_agent_rag/ui/private_file_rag_interface.py +663 -0
- deep_agent_rag/utils/llm_utils.py +68 -45
- pyproject.toml +11 -0
- src/__init__.py +37 -0
- src/document_processor.py +590 -0
- src/hybrid_subquery_hyde_rag.py +399 -0
- src/hyde_rag.py +235 -0
- src/llm_integration.py +246 -0
- src/prompt_formatter.py +395 -0
- src/retrievers/__init__.py +17 -0
- src/retrievers/base.py +32 -0
- src/retrievers/bm25_retriever.py +127 -0
- src/retrievers/hybrid_search.py +298 -0
- src/retrievers/reranker.py +448 -0
- src/retrievers/vector_retriever.py +254 -0
- src/step_back_rag.py +305 -0
- src/subquery_rag.py +361 -0
- src/triple_hybrid_rag.py +467 -0
- uv.lock +105 -0
.gitignore
CHANGED
|
@@ -19,3 +19,12 @@ token.json
|
|
| 19 |
|
| 20 |
*test_check*
|
| 21 |
.DS_Store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
*test_check*
|
| 21 |
.DS_Store
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
.cursor/*
|
| 25 |
+
chroma_db*/*
|
| 26 |
+
chroma_db*/
|
Deep_Agent_Gradio_RAG_localLLM_main.py
CHANGED
|
@@ -25,7 +25,7 @@ def main():
|
|
| 25 |
"""主函數:初始化系統並啟動 Gradio 界面"""
|
| 26 |
print("\n🚀 Deep Research Agent with RAG (Local MLX Edition) 啟動!")
|
| 27 |
print("💡 本系統整合了:股票查詢、網路搜尋、PDF 知識庫查詢功能\n")
|
| 28 |
-
|
| 29 |
|
| 30 |
# 初始化 Parlant SDK
|
| 31 |
print("🔧 正在初始化 Parlant SDK...")
|
|
|
|
| 25 |
"""主函數:初始化系統並啟動 Gradio 界面"""
|
| 26 |
print("\n🚀 Deep Research Agent with RAG (Local MLX Edition) 啟動!")
|
| 27 |
print("💡 本系統整合了:股票查詢、網路搜尋、PDF 知識庫查詢功能\n")
|
| 28 |
+
|
| 29 |
|
| 30 |
# 初始化 Parlant SDK
|
| 31 |
print("🔧 正在初始化 Parlant SDK...")
|
OLLAMA_SETUP.md
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ollama 設置指南
|
| 2 |
+
|
| 3 |
+
本指南說明如何在 Deep Agentic AI Tool 中設置和使用 Ollama,特別是 Llama 3.2 3B 模型。
|
| 4 |
+
|
| 5 |
+
## 📋 前置需求
|
| 6 |
+
|
| 7 |
+
- macOS 或 Linux 系統
|
| 8 |
+
- 至少 16GB 記憶體(推薦)
|
| 9 |
+
- Python >= 3.13
|
| 10 |
+
|
| 11 |
+
## 🚀 安裝步驟
|
| 12 |
+
|
| 13 |
+
### 1. 安裝 Ollama
|
| 14 |
+
|
| 15 |
+
**macOS:**
|
| 16 |
+
```bash
|
| 17 |
+
brew install ollama
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
或從官網下載:https://ollama.com
|
| 21 |
+
|
| 22 |
+
**Linux:**
|
| 23 |
+
```bash
|
| 24 |
+
curl -fsSL https://ollama.com/install.sh | sh
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
### 2. 下載 Llama 3.2 模型
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
ollama pull llama3.2:3b
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
這會下載約 2GB 的模型文件。
|
| 34 |
+
|
| 35 |
+
### 3. 啟動 Ollama 服務
|
| 36 |
+
|
| 37 |
+
Ollama 通常會自動啟動,如果需要手動啟動:
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
ollama serve
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
服務預設運行在 `http://localhost:11434`
|
| 44 |
+
|
| 45 |
+
### 4. 驗證安裝
|
| 46 |
+
|
| 47 |
+
測試模型是否可用:
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
ollama run llama3.2:3b "Hello, how are you?"
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## ⚙️ 配置專案
|
| 54 |
+
|
| 55 |
+
### 1. 更新環境變數
|
| 56 |
+
|
| 57 |
+
在專案根目錄的 `.env` 文件中添加:
|
| 58 |
+
|
| 59 |
+
```env
|
| 60 |
+
# 啟用 Ollama
|
| 61 |
+
USE_OLLAMA=true
|
| 62 |
+
OLLAMA_BASE_URL=http://localhost:11434
|
| 63 |
+
OLLAMA_MODEL=llama3.2:3b
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
### 2. 可選配置
|
| 67 |
+
|
| 68 |
+
如果需要使用其他 Ollama 模型,可以修改:
|
| 69 |
+
|
| 70 |
+
```env
|
| 71 |
+
OLLAMA_MODEL=qwen2.5:7b # 使用 Qwen2.5
|
| 72 |
+
OLLAMA_MODEL=llama3.1:8b # 使用 Llama 3.1
|
| 73 |
+
OLLAMA_MODEL=deepseek-r1:7b # 使用 DeepSeek-R1
|
| 74 |
+
OLLAMA_MODEL=mistral:7b # 使用 Mistral
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
## 🎯 使用方式
|
| 78 |
+
|
| 79 |
+
系統會按照以下優先順序自動選擇 LLM:
|
| 80 |
+
|
| 81 |
+
1. **Groq API**(如果配置了 `GROQ_API_KEY`)
|
| 82 |
+
2. **Ollama**(如果 `USE_OLLAMA=true` 且服務可用)
|
| 83 |
+
3. **MLX 模型**(備援選項)
|
| 84 |
+
|
| 85 |
+
當 Groq API 額度用完時,系統會自動切換到 Ollama(如果啟用),否則使用 MLX 模型。
|
| 86 |
+
|
| 87 |
+
## 🔍 檢查當前使用的模型
|
| 88 |
+
|
| 89 |
+
啟動應用後,查看控制台輸出:
|
| 90 |
+
|
| 91 |
+
- `✅ 使用 Groq API (優先)` - 使用 Groq API
|
| 92 |
+
- `✅ 使用 Ollama 模型 (llama3.2:3b)` - 使用 Ollama
|
| 93 |
+
- `ℹ️ 使用本地 MLX 模型` - 使用 MLX 模型
|
| 94 |
+
|
| 95 |
+
## 🐛 故障排除
|
| 96 |
+
|
| 97 |
+
### Ollama 服務無法連接
|
| 98 |
+
|
| 99 |
+
**問題:** `⚠️ Ollama 初始化失敗: Connection refused`
|
| 100 |
+
|
| 101 |
+
**解決方案:**
|
| 102 |
+
1. 確認 Ollama 服務正在運行:`ollama serve`
|
| 103 |
+
2. 檢查端口是否被占用:`lsof -i :11434`
|
| 104 |
+
3. 確認 `OLLAMA_BASE_URL` 配置正確
|
| 105 |
+
|
| 106 |
+
### 模型找不到
|
| 107 |
+
|
| 108 |
+
**問題:** `⚠️ Ollama 初始化失敗: model not found`
|
| 109 |
+
|
| 110 |
+
**解決方案:**
|
| 111 |
+
```bash
|
| 112 |
+
# 下載模型
|
| 113 |
+
ollama pull llama3.2:3b
|
| 114 |
+
|
| 115 |
+
# 列出已安裝的模型
|
| 116 |
+
ollama list
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### 記憶體不足
|
| 120 |
+
|
| 121 |
+
**問題:** 系統運行緩慢或崩潰
|
| 122 |
+
|
| 123 |
+
**解決方案:**
|
| 124 |
+
- Llama 3.2:3B 需要約 2GB RAM
|
| 125 |
+
- 確保系統有足夠的可用記憶體(推薦至少 8GB)
|
| 126 |
+
- 這個模型已經很輕量,適合 16GB 記憶體的系統
|
| 127 |
+
|
| 128 |
+
## 📊 模型比較
|
| 129 |
+
|
| 130 |
+
| 模型 | 大小 | 記憶體需求 | 特點 |
|
| 131 |
+
|------|------|-----------|------|
|
| 132 |
+
| llama3.2:3b | ~2GB | ~4GB | 輕量高效,適合 16GB 記憶體系統,Meta 開源 |
|
| 133 |
+
| deepseek-r1:7b | ~4.7GB | ~8GB | 優秀的推理能力,適合數學、編程 |
|
| 134 |
+
| qwen2.5:7b | ~4.5GB | ~8GB | 通用能力強,中英文支援好 |
|
| 135 |
+
| llama3.1:8b | ~4.6GB | ~8GB | Meta 開源,性能穩定 |
|
| 136 |
+
| mistral:7b | ~4.1GB | ~7GB | 速度快,效率高 |
|
| 137 |
+
|
| 138 |
+
## 💡 性能優化建議
|
| 139 |
+
|
| 140 |
+
1. **優先使用 Groq API**:如果可用,Groq API 速度最快
|
| 141 |
+
2. **Ollama 作為備援**:當 Groq 不可用時,Ollama 提供良好的本地推理
|
| 142 |
+
3. **MLX 作為最後備援**:在 Apple Silicon 上,MLX 模型有硬體優化
|
| 143 |
+
|
| 144 |
+
## 📚 相關資源
|
| 145 |
+
|
| 146 |
+
- [Ollama 官方文檔](https://ollama.com/docs)
|
| 147 |
+
- [Llama 3.2 模型資訊](https://ollama.com/library/llama3.2)
|
| 148 |
+
- [LangChain Ollama 整合](https://python.langchain.com/docs/integrations/llms/ollama)
|
| 149 |
+
|
| 150 |
+
---
|
| 151 |
+
|
| 152 |
+
**注意**:首次使用時,Ollama 會下載模型文件,這可能需要一些時間,請耐心等待。
|
| 153 |
+
|
PRIVATE_FILE_RAG_GUIDE.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 私有檔案 RAG 功能使用指南
|
| 2 |
+
|
| 3 |
+
## 總覽
|
| 4 |
+
|
| 5 |
+
本功能整合了 `Learn_RAG` 專案的 RAG 系統,讓使用者可以上傳私有檔案(如 PDF、DOCX、TXT),並基於這些檔案內容進行智慧問答。系統採用了先進的混合檢索與 LLM 技術,以提供準確、相關的回答。
|
| 6 |
+
|
| 7 |
+
## 功能特色
|
| 8 |
+
|
| 9 |
+
- ✅ **支援多種檔案格式**:PDF、DOCX、DOC、TXT。
|
| 10 |
+
- ✅ **支援多檔案上傳**:可一次上傳並處理多個檔案。
|
| 11 |
+
- ✅ **混合檢索**:結合 BM25(關鍵字檢索)與向量檢索(語義檢索),大幅提升檢索準確度。
|
| 12 |
+
- ✅ **可選重排序**:使用 BGE Reranker 模型進一步優化檢索結果的相關性。
|
| 13 |
+
- ✅ **兩種分塊模式**:
|
| 14 |
+
- **語義分塊(推薦)**:智慧切分文件,確保語義完整性,不會在句子中間斷開,提升檢索品質。
|
| 15 |
+
- **字元分塊**:依固定字數切分,處理速度快,適合快速測試。
|
| 16 |
+
- ✅ **智慧回答生成**:
|
| 17 |
+
- **LLM 自動切換策略**:系統會自動根據可用性選擇最佳的 LLM,優先順序為:**Groq API > Ollama > MLX 本地模型**。
|
| 18 |
+
- **自動化提示工程**:自動偵測檔案類型(如學術論文、履歷、通用文件)以調整提問風格,生成更貼切的回答。
|
| 19 |
+
- ✅ **支援中英文問答**。
|
| 20 |
+
|
| 21 |
+
## 使用方法
|
| 22 |
+
|
| 23 |
+
### 1. 準備工作
|
| 24 |
+
|
| 25 |
+
確保 `Learn_RAG` 專案與本專案 (`Deep_Agentic_AI_Tool`) 位於**同一個父目錄**下。正確的目錄結構應如下:
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
/some_parent_directory/
|
| 29 |
+
├─── Deep_Agentic_AI_Tool/ (本專案)
|
| 30 |
+
└─── Learn_RAG/
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
此外,請確保 `Learn_RAG` 的依賴已安裝。您可以在 `Learn_RAG` 目錄下執行 `uv sync` 或 `pip install -r requirements.txt`。
|
| 34 |
+
|
| 35 |
+
### 2. 啟動系統
|
| 36 |
+
|
| 37 |
+
執行主程式以啟動 Gradio 網頁介面:
|
| 38 |
+
```bash
|
| 39 |
+
python Deep_Agent_Gradio_RAG_localLLM_main.py
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
### 3. 使用步驟
|
| 43 |
+
|
| 44 |
+
1. **開啟介面**:在瀏覽器中開啟 Gradio 介面(預設為 `http://0.0.0.0:7860`),並點擊 **"📚 Private File RAG"** 標籤頁。
|
| 45 |
+
|
| 46 |
+
2. **上傳檔案**:點擊或拖曳檔案至 **"📁 上傳檔案"** 區域,可選擇一個或多個 PDF、DOCX 或 TXT 檔案。
|
| 47 |
+
|
| 48 |
+
3. **設定分塊模式**:
|
| 49 |
+
- **使用語義分塊(推薦)**:勾選此選項以獲得最佳的檢索品質。處理時間較長,但效果最好。
|
| 50 |
+
- **調整參數(可選)**:介面提供了對兩種分塊模式的進階參數調整,您可以根據需求調整,或直接使用已優化的預設值。
|
| 51 |
+
|
| 52 |
+
4. **處理檔案**:點擊 **"📝 處理檔案"** 按鈕。系統會根據您的設定進行分塊、建立索引並初始化 RAG 系統。請等待處理狀態顯示 "✅ 文件處理完成"。
|
| 53 |
+
- **注意**:首次使用時,系統需要下載 Embedding 模型,可能需要數分鐘時間,請耐心等候。
|
| 54 |
+
|
| 55 |
+
5. **提出問題**:在 **"❓ 請輸入您的問題"** 輸入框中,輸入您想詢問關於檔案內容的問題。
|
| 56 |
+
|
| 57 |
+
6. **調整查詢選項**:
|
| 58 |
+
- **返回結果數量**:可調整檢索到的相關文件片段數量(預設為 3)。
|
| 59 |
+
- **使用 LLM 生成回答**:勾選此項,AI 會總結檢索到的內容並生成流暢的回答。若取消勾選,則僅顯示原始的文件片段。
|
| 60 |
+
|
| 61 |
+
7. **執行查詢**:點擊 **"🔍 查詢"** 按鈕。
|
| 62 |
+
|
| 63 |
+
8. **檢視結果**:
|
| 64 |
+
- **💬 AI 回答**:顯示由 LLM 生成的最終回答。
|
| 65 |
+
- **📄 檢索到的文件片段**:顯示用於生成回答的原始文件內容、來源及相關性分數,方便您查證。
|
| 66 |
+
|
| 67 |
+
9. **清除與重置**:點擊 **"🗑️ 清除"** 按鈕可重設當前會話,讓您重新上傳檔案。
|
| 68 |
+
|
| 69 |
+
## 技術細節
|
| 70 |
+
|
| 71 |
+
### LLM 使用策略
|
| 72 |
+
|
| 73 |
+
本系統採用彈性的 LLM 調度策略,無需手動設定:
|
| 74 |
+
1. 🥇 **Groq API**:若您在環境變數中設定了 `GROQ_API_KEY`,系統會優先使用速度極快的 Groq API。
|
| 75 |
+
2. 🥈 **Ollama**:若 Groq API 不可用或額度用盡,系統會自動切換至本地運行的 Ollama 模型(如 Llama3.2)。
|
| 76 |
+
3. 🥉 **MLX**:若前兩者皆不可用,系統會使用 Apple MLX 在本地運行模型(如 Qwen2.5)作為最終備案。
|
| 77 |
+
|
| 78 |
+
### 分塊模式詳解
|
| 79 |
+
|
| 80 |
+
- **字元分塊**(預設 `chunk_size: 500`, `chunk_overlap: 100`):
|
| 81 |
+
- **優點**:處理速度快。
|
| 82 |
+
- **適用場景**:快速測試、對語義完整性要求不高的文件。
|
| 83 |
+
- **參數說明**:
|
| 84 |
+
- `分塊大小`:每個區塊的字元數。較小值粒度更細,較大值上下文更完整。
|
| 85 |
+
- `分塊重疊`:相鄰區塊間重疊的字元數,有助於保持上下文連貫。
|
| 86 |
+
|
| 87 |
+
- **語義分塊**(預設 `threshold: 1.0`, `min_chunk_size: 100`):
|
| 88 |
+
- **優點**:根據語義邊界切分,能保持句子和段落的完整性,檢索品質更高。
|
| 89 |
+
- **適用場景**:專業文件、報告、論文等需要精準理解上下文的場景。
|
| 90 |
+
- **參數說明**:
|
| 91 |
+
- `語義分塊閾值`:控制分塊的敏感度。數值越小,分塊越細。建議值為 0.8-1.2。
|
| 92 |
+
- `最小分塊大小`:過小的區塊會被合併,以避免碎片化。
|
| 93 |
+
|
| 94 |
+
### 檢索系統
|
| 95 |
+
|
| 96 |
+
1. **Embedding 模型**:使用 `sentence-transformers/all-MiniLM-L6-v2` 將文字轉換為向量,此模型輕量且高效。
|
| 97 |
+
2. **向量資料庫**:使用 `ChromaDB` 儲存並索引向量,資料庫會持久化儲存於 `./chroma_db_private` 目錄。
|
| 98 |
+
3. **混合檢索**:結合 `BM25`(基於關鍵字)和 `向量檢索`(基於語義),並透過 `RRF` (Reciprocal Rank Fusion) 演算法融合結果,兼顧關鍵字匹配和語義相似度。
|
| 99 |
+
4. **重排序器(Reranker)**:使用 `BAAI/bge-reranker-base` 模型對混合檢索的結果進行二次排序,將最相關的片段排在最前面,極大化提升了最終答案的品質。
|
| 100 |
+
|
| 101 |
+
## 常見問題
|
| 102 |
+
|
| 103 |
+
### Q: 處理檔案時出錯或提示 `Learn_RAG` 模組不可用?
|
| 104 |
+
**A:** 請檢查:
|
| 105 |
+
1. **專案位置**:確保 `Learn_RAG` 專案目錄與 `Deep_Agentic_AI_Tool` 位於同一父目錄下。
|
| 106 |
+
2. **依賴安裝**:確認您已安裝 `Learn_RAG` 的所有 Python 依賴。最簡單的方式是進入 `Learn_RAG` 目錄並執行 `uv sync`。
|
| 107 |
+
3. **檔案格式**:確認您上傳的是支援的格式(PDF, DOCX, DOC, TXT)且檔案未損毀。
|
| 108 |
+
|
| 109 |
+
### Q: AI 無法生成回答,或回答很慢?
|
| 110 |
+
**A:** 請檢查:
|
| 111 |
+
1. **LLM 服務**:如果您想使用 Ollama,請確保 Ollama 服務正在本地運行(可透過終端機執行 `ollama serve`)。
|
| 112 |
+
2. **模型下載**:確認 Ollama 需要的模型已經下載(如 `ollama pull llama3.2:3b`)。
|
| 113 |
+
3. **LLM 狀態**:系統會自動從 Groq 切換至本地模型。若切換至 MLX,處理速度會較慢,請耐心等待。若不需生成式回答,可取消勾選 "使用 LLM 生成回答" 以直接查看檢索結果。
|
| 114 |
+
|
| 115 |
+
### Q: "清除" 按鈕的功能是什麼?
|
| 116 |
+
**A:** "清除" 按鈕會重設當前 Gradio 會話中的 RAG 系統,清空已上傳的檔案和記憶體中的索引。這讓您可以重新上傳並處理新的一批檔案。
|
| 117 |
+
**注意**:此按鈕**不會**刪除磁碟上 `./chroma_db_private` 目錄中持久化的向量資料庫。若要完全清空所有資料,您需要手動刪除該目錄。
|
| 118 |
+
|
| 119 |
+
### Q: 處理檔案速度很慢?
|
| 120 |
+
**A:**
|
| 121 |
+
1. 首次使用時,系統需要下載數百 MB 的 Embedding 模型,請耐心等待。
|
| 122 |
+
2. 語義分塊模式會進行複雜的計算,處理時間比字元分塊長,但效果更好。
|
| 123 |
+
3. 檔案越大、數量越多,處理時間越長。
|
| 124 |
+
|
| 125 |
+
## 依賴項
|
| 126 |
+
|
| 127 |
+
- `Learn_RAG` 專案及其所有依賴項。
|
| 128 |
+
- `langchain-community`, `sentence-transformers`, `chromadb`, `rank_bm25`, `pypdf`, `docx2txt` 等。
|
| 129 |
+
- `ollama` (若使用本地 Ollama LLM)。
|
README.md
CHANGED
|
@@ -66,9 +66,29 @@
|
|
| 66 |
# 使用 uv(推薦)
|
| 67 |
uv sync
|
| 68 |
|
|
|
|
| 69 |
# 或使用 pip
|
| 70 |
pip install -e .
|
| 71 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
### 2. 環境變數配置
|
| 74 |
|
|
@@ -220,6 +240,7 @@ Deep_Agentic_AI_Tool/
|
|
| 220 |
3. 在 `get_tools_list()` 中添加工具
|
| 221 |
4. 代理會自動發現並使用新工具
|
| 222 |
|
|
|
|
| 223 |
### 修改代理邏輯
|
| 224 |
|
| 225 |
- **規劃邏輯**:編輯 `deep_agent_rag/agents/planner.py`
|
|
@@ -231,6 +252,41 @@ Deep_Agentic_AI_Tool/
|
|
| 231 |
編輯 `deep_agent_rag/ui/gradio_interface.py` 修改 Web 界面。
|
| 232 |
|
| 233 |
**詳細開發指南請參考:[系統架構](ARCHITECTURE.md#開發指南)**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
## 📦 主要依賴
|
| 236 |
|
|
@@ -288,14 +344,75 @@ Deep_Agentic_AI_Tool/
|
|
| 288 |
|
| 289 |
## 📧 聯絡
|
| 290 |
|
|
|
|
| 291 |
[添加聯絡資訊]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
## 🙏 致謝
|
| 294 |
|
|
|
|
| 295 |
- **LangChain & LangGraph**:優秀的代理框架
|
| 296 |
- **MLX Team**:高效的本地模型推理
|
| 297 |
- **Qwen Team**:Qwen2.5 模型
|
| 298 |
- **Jina AI**:嵌入模型
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
---
|
| 301 |
|
|
|
|
| 66 |
# 使用 uv(推薦)
|
| 67 |
uv sync
|
| 68 |
|
| 69 |
+
<<<<<<< HEAD
|
| 70 |
# 或使用 pip
|
| 71 |
pip install -e .
|
| 72 |
```
|
| 73 |
+
=======
|
| 74 |
+
3. **Set up environment variables** (create a `.env` file in the root directory):
|
| 75 |
+
```env
|
| 76 |
+
# Optional: Groq API (for faster inference)
|
| 77 |
+
GROQ_API_KEY=your_groq_api_key_here
|
| 78 |
+
|
| 79 |
+
# Optional: Ollama (for local inference with Llama 3.2 or other models)
|
| 80 |
+
USE_OLLAMA=true
|
| 81 |
+
OLLAMA_BASE_URL=http://localhost:11434
|
| 82 |
+
OLLAMA_MODEL=llama3.2:3b
|
| 83 |
+
|
| 84 |
+
# Optional: Tavily API (for web search)
|
| 85 |
+
TAVILY_API_KEY=your_tavily_api_key_here
|
| 86 |
+
|
| 87 |
+
# Optional: Gmail API credentials
|
| 88 |
+
GMAIL_CREDENTIALS_FILE=credentials.json
|
| 89 |
+
GMAIL_TOKEN_FILE=token.json
|
| 90 |
+
```
|
| 91 |
+
>>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
|
| 92 |
|
| 93 |
### 2. 環境變數配置
|
| 94 |
|
|
|
|
| 240 |
3. 在 `get_tools_list()` 中添加工具
|
| 241 |
4. 代理會自動發現並使用新工具
|
| 242 |
|
| 243 |
+
<<<<<<< HEAD
|
| 244 |
### 修改代理邏輯
|
| 245 |
|
| 246 |
- **規劃邏輯**:編輯 `deep_agent_rag/agents/planner.py`
|
|
|
|
| 252 |
編輯 `deep_agent_rag/ui/gradio_interface.py` 修改 Web 界面。
|
| 253 |
|
| 254 |
**詳細開發指南請參考:[系統架構](ARCHITECTURE.md#開發指南)**
|
| 255 |
+
=======
|
| 256 |
+
The system supports multiple LLM backends with automatic fallback (priority order):
|
| 257 |
+
|
| 258 |
+
1. **Primary**: Groq API (fastest, requires API key)
|
| 259 |
+
- Model: `llama-3.3-70b-versatile`
|
| 260 |
+
- Automatically used if `GROQ_API_KEY` is set
|
| 261 |
+
|
| 262 |
+
2. **Secondary**: Ollama (local inference, excellent reasoning capabilities)
|
| 263 |
+
- Default Model: `llama3.2:3b` (Llama 3.2 3B)
|
| 264 |
+
- Requires Ollama installed and model downloaded
|
| 265 |
+
- Enable with `USE_OLLAMA=true` in `.env`
|
| 266 |
+
- Lightweight and efficient, suitable for 16GB memory systems
|
| 267 |
+
- Automatically used when Groq API is unavailable or quota exhausted
|
| 268 |
+
|
| 269 |
+
3. **Fallback**: Local MLX Model (privacy-preserving, no API key needed)
|
| 270 |
+
- Model: `mlx-community/Qwen2.5-Coder-7B-Instruct-4bit`
|
| 271 |
+
- Automatically used when both Groq API and Ollama are unavailable
|
| 272 |
+
|
| 273 |
+
The system automatically switches between backends based on availability.
|
| 274 |
+
|
| 275 |
+
**Setting up Ollama:**
|
| 276 |
+
```bash
|
| 277 |
+
# Install Ollama (if not already installed)
|
| 278 |
+
# macOS: brew install ollama
|
| 279 |
+
# Or download from https://ollama.com
|
| 280 |
+
|
| 281 |
+
# Download Llama 3.2 model
|
| 282 |
+
ollama pull llama3.2:3b
|
| 283 |
+
|
| 284 |
+
# Start Ollama service (usually runs automatically)
|
| 285 |
+
ollama serve
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
## ⚙️ Configuration
|
| 289 |
+
>>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
|
| 290 |
|
| 291 |
## 📦 主要依賴
|
| 292 |
|
|
|
|
| 344 |
|
| 345 |
## 📧 聯絡
|
| 346 |
|
| 347 |
+
<<<<<<< HEAD
|
| 348 |
[添加聯絡資訊]
|
| 349 |
+
=======
|
| 350 |
+
- **LangChain**: Agent framework and tool integration
|
| 351 |
+
- **LangGraph**: Agent orchestration and workflow management
|
| 352 |
+
- **MLX/MLX-LM**: Local model inference (Apple Silicon optimized)
|
| 353 |
+
- **LangChain Ollama**: Ollama integration for local models
|
| 354 |
+
- **Gradio**: Web interface
|
| 355 |
+
- **ChromaDB**: Vector database for RAG
|
| 356 |
+
- **Tavily**: Web search API
|
| 357 |
+
- **yfinance**: Stock data retrieval
|
| 358 |
+
- **Google API Client**: Gmail API integration
|
| 359 |
+
>>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
|
| 360 |
|
| 361 |
## 🙏 致謝
|
| 362 |
|
| 363 |
+
<<<<<<< HEAD
|
| 364 |
- **LangChain & LangGraph**:優秀的代理框架
|
| 365 |
- **MLX Team**:高效的本地模型推理
|
| 366 |
- **Qwen Team**:Qwen2.5 模型
|
| 367 |
- **Jina AI**:嵌入模型
|
| 368 |
+
=======
|
| 369 |
+
### MLX Model Issues
|
| 370 |
+
|
| 371 |
+
- **Model not loading**: Ensure you have sufficient disk space and memory
|
| 372 |
+
- **Slow inference**: This is normal for local models. Consider using Groq API for faster results
|
| 373 |
+
|
| 374 |
+
### Groq API Issues
|
| 375 |
+
|
| 376 |
+
- **Quota exhausted**: The system automatically falls back to Ollama (if enabled) or local MLX model
|
| 377 |
+
- **API errors**: Check your `GROQ_API_KEY` in `.env` file
|
| 378 |
+
|
| 379 |
+
### Ollama Issues
|
| 380 |
+
|
| 381 |
+
- **Ollama not starting**: Ensure Ollama service is running (`ollama serve`)
|
| 382 |
+
- **Model not found**: Download the model first (`ollama pull llama3.2:3b`)
|
| 383 |
+
- **Connection errors**: Check `OLLAMA_BASE_URL` in `.env` (default: `http://localhost:11434`)
|
| 384 |
+
- **Memory issues**: Llama 3.2:3B requires ~2GB RAM, suitable for systems with 16GB memory
|
| 385 |
+
|
| 386 |
+
### RAG System Issues
|
| 387 |
+
|
| 388 |
+
- **PDF not found**: Ensure the PDF file exists at the path specified in `config.py`
|
| 389 |
+
- **Embedding model errors**: The system will attempt to re-download the model if cache is corrupted
|
| 390 |
+
|
| 391 |
+
### Gmail API Issues
|
| 392 |
+
|
| 393 |
+
- **Authorization errors**: Delete `token.json` and re-authorize
|
| 394 |
+
- **Credentials not found**: Ensure `credentials.json` is in the project root
|
| 395 |
+
- See `GMAIL_API_SETUP.md` for detailed setup instructions
|
| 396 |
+
|
| 397 |
+
## 📝 License
|
| 398 |
+
|
| 399 |
+
[Add your license information here]
|
| 400 |
+
|
| 401 |
+
## 🤝 Contributing
|
| 402 |
+
|
| 403 |
+
[Add contribution guidelines here]
|
| 404 |
+
|
| 405 |
+
## 📧 Contact
|
| 406 |
+
|
| 407 |
+
[Add contact information here]
|
| 408 |
+
|
| 409 |
+
## 🙏 Acknowledgments
|
| 410 |
+
|
| 411 |
+
- **LangChain & LangGraph**: For the excellent agent framework
|
| 412 |
+
- **MLX Team**: For efficient local model inference
|
| 413 |
+
- **Qwen Team**: For the Qwen2.5 model
|
| 414 |
+
- **Jina AI**: For the embedding model
|
| 415 |
+
>>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
|
| 416 |
|
| 417 |
---
|
| 418 |
|
deep_agent_rag/config.py
CHANGED
|
@@ -49,6 +49,13 @@ GROQ_MAX_TOKENS = 2048
|
|
| 49 |
GROQ_TEMPERATURE = 0.7
|
| 50 |
USE_GROQ_FIRST = True # 是否优先使用 Groq API
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
# Email 配置 - 使用 Gmail API
|
| 53 |
EMAIL_SENDER = "matthuang46@gmail.com"
|
| 54 |
# Gmail API 配置
|
|
|
|
| 49 |
GROQ_TEMPERATURE = 0.7
|
| 50 |
USE_GROQ_FIRST = True # 是否优先使用 Groq API
|
| 51 |
|
| 52 |
+
# Ollama 配置
|
| 53 |
+
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
| 54 |
+
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "llama3.2:3b") # Llama 3.2 3B
|
| 55 |
+
OLLAMA_MAX_TOKENS = 2048
|
| 56 |
+
OLLAMA_TEMPERATURE = 0.7
|
| 57 |
+
USE_OLLAMA = os.getenv("USE_OLLAMA", "false").lower() == "true" # 是否啟用 Ollama
|
| 58 |
+
|
| 59 |
# Email 配置 - 使用 Gmail API
|
| 60 |
EMAIL_SENDER = "matthuang46@gmail.com"
|
| 61 |
# Gmail API 配置
|
deep_agent_rag/rag/adaptive_rag_selector.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Adaptive RAG method selector.

Automatically selects the best RAG method based on characteristics of the
user's query and of the uploaded files.
"""
from typing import Dict, List, Optional
from enum import Enum
import re
import logging

logger = logging.getLogger(__name__)


class RAGMethod(Enum):
    """Available RAG methods."""
    BASIC = "basic"                                # plain RAG (the current default)
    SUBQUERY = "subquery"                          # sub-query decomposition
    HYDE = "hyde"                                  # hypothetical document embeddings
    STEP_BACK = "step_back"                        # step-back reasoning
    HYBRID_SUBQUERY_HYDE = "hybrid_subquery_hyde"  # sub-query + HyDE hybrid
    TRIPLE_HYBRID = "triple_hybrid"                # sub-query + HyDE + step-back


class QueryComplexity(Enum):
    """Query complexity levels."""
    SIMPLE = "simple"              # single short question
    MODERATE = "moderate"          # contains several concepts
    COMPLEX = "complex"            # multi-part question that needs decomposition
    VERY_COMPLEX = "very_complex"  # several related questions


class QueryType(Enum):
    """Query categories."""
    FACTUAL = "factual"            # factual query ("what is X")
    CONCEPTUAL = "conceptual"      # conceptual query ("how to understand X")
    COMPARATIVE = "comparative"    # comparative query ("difference between X and Y")
    PRINCIPLE = "principle"        # mechanism query ("how does X work")
    MULTI_ASPECT = "multi_aspect"  # contains multiple questions


class AdaptiveRAGSelector:
    """
    Adaptive RAG method selector.

    Chooses the best RAG method based on:
    1. query complexity
    2. query type
    3. number and type of files
    4. document complexity
    """

    def __init__(self):
        """Initialize the selector (stateless; nothing to set up)."""
        pass

    @staticmethod
    def _count_questions(query: str) -> int:
        """Count question marks, accepting both ASCII '?' and fullwidth '？'.

        BUG FIX: the original code counted ASCII '?' twice
        (``query.count('?') + query.count('?')``), doubling the count and
        classifying any single-question query as multi-question.
        """
        return query.count('?') + query.count('？')

    def analyze_query(self, query: str) -> Dict:
        """
        Analyze query characteristics.

        Args:
            query: The user's question.

        Returns:
            Dict with keys: complexity, type, word_count, length,
            has_multiple_questions, is_comparative, has_technical_terms,
            needs_explanation, question_count.
        """
        query_lower = query.lower()
        query_len = len(query)
        word_count = len(query.split())

        # Detect query complexity.
        complexity = self._detect_complexity(query, word_count)

        # Detect query type.
        query_type = self._detect_query_type(query, query_lower)

        # Detect whether the query contains multiple questions.
        question_count = self._count_questions(query)
        has_multiple_questions = question_count > 1

        # Detect comparative vocabulary.
        # NOTE(review): generic conjunctions ('and', '和', '与') make this flag
        # intentionally permissive; many non-comparative queries will match.
        comparison_keywords = ['vs', 'versus', 'difference', '区别', '比较', 'compare', '对比', '和', 'and', '与']
        is_comparative = any(kw in query_lower for kw in comparison_keywords)

        # Detect technical terminology (duplicate '原理' entry removed).
        technical_indicators = [
            '原理', 'mechanism', 'algorithm', 'architecture', 'model', 'system',
            '机制', '算法', '架构', '模型', '系统', '方法', 'method',
            '如何工作', 'how does', 'how do', 'work', 'function'
        ]
        has_technical_terms = any(ind in query_lower for ind in technical_indicators)

        # Detect "why"/"how" words that call for an explanation.
        explanation_keywords = ['为什么', 'why', '如何', 'how', 'explain', '解释', '说明']
        needs_explanation = any(kw in query_lower for kw in explanation_keywords)

        return {
            'complexity': complexity,
            'type': query_type,
            'word_count': word_count,
            'length': query_len,
            'has_multiple_questions': has_multiple_questions,
            'is_comparative': is_comparative,
            'has_technical_terms': has_technical_terms,
            'needs_explanation': needs_explanation,
            'question_count': question_count
        }

    def _detect_complexity(self, query: str, word_count: int) -> QueryComplexity:
        """Classify query complexity from word count and question count."""
        # Simple: short, at most one question.
        if word_count <= 10 and self._count_questions(query) <= 1:
            return QueryComplexity.SIMPLE

        # Moderate: medium length, may touch several concepts.
        if word_count <= 25:
            return QueryComplexity.MODERATE

        # Complex: long, multiple questions or concepts.
        if word_count <= 50:
            return QueryComplexity.COMPLEX

        # Very complex: very long, many questions.
        return QueryComplexity.VERY_COMPLEX

    def _detect_query_type(self, query: str, query_lower: str) -> QueryType:
        """Classify the query category via keyword matching (first match wins)."""
        # Comparative query.
        if any(kw in query_lower for kw in ['vs', 'versus', 'difference', '区别', '比较', 'compare', '对比', '和', 'and', '与']):
            return QueryType.COMPARATIVE

        # Principle / mechanism query.
        if any(kw in query_lower for kw in ['原理', 'principle', 'how does', 'how do', 'mechanism', '如何工作', '工作原理']):
            return QueryType.PRINCIPLE

        # Conceptual query.
        if any(kw in query_lower for kw in ['什么是', 'what is', '理解', 'understand', 'explain', '解释']):
            return QueryType.CONCEPTUAL

        # Multi-aspect query.
        if self._count_questions(query) > 1:
            return QueryType.MULTI_ASPECT

        # Default: factual query.
        return QueryType.FACTUAL

    def analyze_files(self, file_paths: List[str], documents: Optional[List[Dict]] = None) -> Dict:
        """
        Analyze file characteristics.

        Args:
            file_paths: List of file paths.
            documents: Optional list of already-processed document chunks;
                each is expected to carry a 'content' key.

        Returns:
            Dict with keys: file_count, file_types, total_chunks,
            avg_chunk_size, is_academic, is_single_file, is_multi_file.
        """
        file_count = len(file_paths)

        # Map each path to a coarse file-type label.
        file_types = []
        for path in file_paths:
            if path.endswith('.pdf'):
                file_types.append('pdf')
            elif path.endswith(('.docx', '.doc')):
                file_types.append('docx')
            else:
                file_types.append('txt')

        # Document complexity, when chunks are available.
        total_chunks = len(documents) if documents else 0
        avg_chunk_size = 0
        if documents:
            chunk_sizes = [len(doc.get('content', '')) for doc in documents]
            avg_chunk_size = sum(chunk_sizes) / len(chunk_sizes) if chunk_sizes else 0

        # Heuristic: any PDF, or 'paper'/'arxiv' in a name, counts as academic.
        is_academic = any('paper' in path.lower() or 'arxiv' in path.lower() or
                          path.endswith('.pdf') for path in file_paths)

        return {
            'file_count': file_count,
            'file_types': file_types,
            'total_chunks': total_chunks,
            'avg_chunk_size': avg_chunk_size,
            'is_academic': is_academic,
            'is_single_file': file_count == 1,
            'is_multi_file': file_count > 1
        }

    def select_best_method(
        self,
        query_features: Dict,
        file_features: Dict,
        enable_advanced: bool = True
    ) -> RAGMethod:
        """
        Select the best RAG method for the given features.

        Args:
            query_features: Features from analyze_query().
            file_features: Features from analyze_files().
            enable_advanced: When False, always return the basic method.

        Returns:
            The selected RAGMethod.
        """
        if not enable_advanced:
            return RAGMethod.BASIC

        complexity = query_features['complexity']
        query_type = query_features['type']
        has_multiple_questions = query_features['has_multiple_questions']
        is_comparative = query_features['is_comparative']
        has_technical_terms = query_features['has_technical_terms']
        needs_explanation = query_features['needs_explanation']
        file_count = file_features['file_count']
        is_multi_file = file_features['is_multi_file']

        # Decision tree (first matching rule wins).

        # 1. Very complex query + multiple files -> Triple Hybrid (strongest).
        if complexity == QueryComplexity.VERY_COMPLEX and is_multi_file:
            return RAGMethod.TRIPLE_HYBRID

        # 2. Complex query with multiple questions -> SubQuery or SubQuery+HyDE.
        if complexity in [QueryComplexity.COMPLEX, QueryComplexity.VERY_COMPLEX]:
            if has_multiple_questions or query_type == QueryType.MULTI_ASPECT:
                if is_multi_file:
                    return RAGMethod.HYBRID_SUBQUERY_HYDE
                else:
                    return RAGMethod.SUBQUERY

        # 3. Principle query -> Step-back (needs background knowledge).
        if query_type == QueryType.PRINCIPLE:
            if complexity in [QueryComplexity.MODERATE, QueryComplexity.COMPLEX]:
                return RAGMethod.STEP_BACK

        # 4. Technical terminology -> HyDE (generate hypothetical documents).
        if has_technical_terms and complexity in [QueryComplexity.MODERATE, QueryComplexity.COMPLEX]:
            return RAGMethod.HYDE

        # 5. Comparative query + multiple files -> SubQuery (retrieve per side).
        if is_comparative and is_multi_file:
            return RAGMethod.SUBQUERY

        # 6. Moderate complexity + multiple files -> SubQuery+HyDE hybrid.
        if complexity == QueryComplexity.MODERATE and is_multi_file:
            return RAGMethod.HYBRID_SUBQUERY_HYDE

        # 7. Simple query -> basic RAG, or HyDE when technical.
        if complexity == QueryComplexity.SIMPLE:
            if has_technical_terms:
                return RAGMethod.HYDE
            else:
                return RAGMethod.BASIC

        # 8. Explanation-seeking query -> Step-back (supplies background).
        if needs_explanation and complexity == QueryComplexity.MODERATE:
            return RAGMethod.STEP_BACK

        # 9. Default for moderate complexity: Step-back.
        if complexity == QueryComplexity.MODERATE:
            return RAGMethod.STEP_BACK

        # 10. Default for complex queries: SubQuery.
        return RAGMethod.SUBQUERY

    def get_method_reason(self, method: RAGMethod, query_features: Dict, file_features: Dict) -> str:
        """
        Return a human-readable (Chinese) reason for the chosen method.

        Args:
            method: The selected RAG method.
            query_features: Features from analyze_query().
            file_features: Features from analyze_files().

        Returns:
            A reason string.
        """
        complexity = query_features['complexity'].value
        query_type = query_features['type'].value
        file_count = file_features['file_count']

        reasons = {
            RAGMethod.BASIC: f"简单查询({complexity}),使用基础 RAG 方法即可",
            RAGMethod.SUBQUERY: f"查询包含多个方面({query_features['question_count']}个问题,{complexity}),使用子查询分解以全面检索",
            RAGMethod.HYDE: f"查询包含专业术语({complexity}),使用假设文档嵌入以改善语义检索",
            RAGMethod.STEP_BACK: f"原理性查询({query_type},{complexity}),使用后退推理获取背景知识和原理",
            RAGMethod.HYBRID_SUBQUERY_HYDE: f"复杂查询({complexity})+ {file_count}个文件,使用混合子查询+HyDE方法以全面检索",
            RAGMethod.TRIPLE_HYBRID: f"非常复杂的查询({complexity})+ {file_count}个文件,使用三重混合方法(SubQuery+HyDE+Step-back)以获得最佳效果"
        }
        return reasons.get(method, f"使用 {method.value} 方法")
|
| 292 |
+
|
deep_agent_rag/rag/llm_adapter.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM 适配器:将 LangChain ChatModel 包装成 OllamaLLM 接口
|
| 3 |
+
用于兼容 Learn_RAG 项目中的进阶 RAG 方法
|
| 4 |
+
"""
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from langchain_core.messages import HumanMessage
|
| 7 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LangChainLLMAdapter:
|
| 14 |
+
"""
|
| 15 |
+
将 LangChain ChatModel 适配为 OllamaLLM 接口
|
| 16 |
+
|
| 17 |
+
这个适配器允许 Learn_RAG 项目中的进阶 RAG 方法(需要 OllamaLLM)
|
| 18 |
+
使用 Deep_Agentic_AI_Tool 的统一 LLM 系统(Groq -> Ollama -> MLX)
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, langchain_llm: BaseChatModel):
|
| 22 |
+
"""
|
| 23 |
+
初始化适配器
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
langchain_llm: LangChain ChatModel 实例(来自 get_llm())
|
| 27 |
+
"""
|
| 28 |
+
self.llm = langchain_llm
|
| 29 |
+
self.model_name = self._detect_model_name()
|
| 30 |
+
self.base_url = "http://localhost:11434" # 默认值,实际不使用
|
| 31 |
+
self.timeout = 120 # 默认值,实际不使用
|
| 32 |
+
|
| 33 |
+
logger.info(f"✅ LLM 适配器初始化完成 (模型类型: {self.model_name})")
|
| 34 |
+
|
| 35 |
+
def _detect_model_name(self) -> str:
|
| 36 |
+
"""
|
| 37 |
+
检测 LLM 类型和模型名称
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
模型名称字符串
|
| 41 |
+
"""
|
| 42 |
+
llm_type = type(self.llm).__name__
|
| 43 |
+
|
| 44 |
+
# 检测 Groq
|
| 45 |
+
if "Groq" in llm_type or "ChatGroq" in llm_type:
|
| 46 |
+
model_name = getattr(self.llm, 'model_name', 'groq-unknown')
|
| 47 |
+
return f"groq:{model_name}"
|
| 48 |
+
|
| 49 |
+
# 检测 Ollama
|
| 50 |
+
if "Ollama" in llm_type or "ChatOllama" in llm_type:
|
| 51 |
+
model_name = getattr(self.llm, 'model', 'ollama-unknown')
|
| 52 |
+
return f"ollama:{model_name}"
|
| 53 |
+
|
| 54 |
+
# 检测 MLX
|
| 55 |
+
if "MLX" in llm_type or "MLXChatModel" in llm_type:
|
| 56 |
+
return "mlx:qwen2.5"
|
| 57 |
+
|
| 58 |
+
# 默认
|
| 59 |
+
return f"langchain:{llm_type}"
|
| 60 |
+
|
| 61 |
+
def _check_ollama_connection(self) -> bool:
|
| 62 |
+
"""
|
| 63 |
+
检查 Ollama 服务是否可用(兼容性方法,实际不使用)
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
总是返回 True(因为我们使用的是统一的 LLM 系统)
|
| 67 |
+
"""
|
| 68 |
+
return True
|
| 69 |
+
|
| 70 |
+
def _check_model_available(self) -> bool:
|
| 71 |
+
"""
|
| 72 |
+
检查模型是否可用(兼容性方法,实际不使用)
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
总是返回 True(因为我们使用的是统一的 LLM 系统)
|
| 76 |
+
"""
|
| 77 |
+
return True
|
| 78 |
+
|
| 79 |
+
def generate(
|
| 80 |
+
self,
|
| 81 |
+
prompt: str,
|
| 82 |
+
temperature: float = 0.7,
|
| 83 |
+
max_tokens: Optional[int] = None,
|
| 84 |
+
stream: bool = False
|
| 85 |
+
) -> str:
|
| 86 |
+
"""
|
| 87 |
+
生成回答(兼容 OllamaLLM.generate 接口)
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
prompt: 输入 prompt
|
| 91 |
+
temperature: 温度参数(0.0-1.0),控制随机性
|
| 92 |
+
max_tokens: 最大生成 token 数(None 表示使用模型预设)
|
| 93 |
+
stream: 是否使用流式输出(当前不支持,总是返回完整结果)
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
生成的回答字符串
|
| 97 |
+
"""
|
| 98 |
+
try:
|
| 99 |
+
# 将 prompt 转换为 LangChain 消息格式
|
| 100 |
+
messages = [HumanMessage(content=prompt)]
|
| 101 |
+
|
| 102 |
+
# 准备调用参数
|
| 103 |
+
invoke_kwargs = {}
|
| 104 |
+
|
| 105 |
+
# 如果 LLM 支持 temperature 参数
|
| 106 |
+
if hasattr(self.llm, 'temperature'):
|
| 107 |
+
# 临时设置 temperature(如果支持)
|
| 108 |
+
original_temp = getattr(self.llm, 'temperature', None)
|
| 109 |
+
try:
|
| 110 |
+
self.llm.temperature = temperature
|
| 111 |
+
except:
|
| 112 |
+
pass # 如果不支持设置,忽略
|
| 113 |
+
|
| 114 |
+
# 如果 LLM 支持 max_tokens 参数
|
| 115 |
+
if max_tokens and hasattr(self.llm, 'max_tokens'):
|
| 116 |
+
original_max_tokens = getattr(self.llm, 'max_tokens', None)
|
| 117 |
+
try:
|
| 118 |
+
self.llm.max_tokens = max_tokens
|
| 119 |
+
except:
|
| 120 |
+
pass # 如果不支持设置,忽略
|
| 121 |
+
|
| 122 |
+
# 调用 LangChain LLM
|
| 123 |
+
response = self.llm.invoke(messages, **invoke_kwargs)
|
| 124 |
+
|
| 125 |
+
# 恢复原始参数(如果之前修改过)
|
| 126 |
+
if hasattr(self.llm, 'temperature') and 'original_temp' in locals():
|
| 127 |
+
try:
|
| 128 |
+
self.llm.temperature = original_temp
|
| 129 |
+
except:
|
| 130 |
+
pass
|
| 131 |
+
|
| 132 |
+
if hasattr(self.llm, 'max_tokens') and 'original_max_tokens' in locals():
|
| 133 |
+
try:
|
| 134 |
+
self.llm.max_tokens = original_max_tokens
|
| 135 |
+
except:
|
| 136 |
+
pass
|
| 137 |
+
|
| 138 |
+
# 提取回答内容
|
| 139 |
+
if hasattr(response, 'content'):
|
| 140 |
+
answer = response.content
|
| 141 |
+
elif isinstance(response, str):
|
| 142 |
+
answer = response
|
| 143 |
+
else:
|
| 144 |
+
answer = str(response)
|
| 145 |
+
|
| 146 |
+
return answer.strip()
|
| 147 |
+
|
| 148 |
+
except Exception as e:
|
| 149 |
+
logger.error(f"⚠️ LLM 生成回答时出错: {e}")
|
| 150 |
+
raise RuntimeError(f"LLM 生成失败: {e}") from e
|
| 151 |
+
|
deep_agent_rag/rag/private_file_rag.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
deep_agent_rag/ui/calendar_interface.py
ADDED
|
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# deep_agent_rag/ui/calendar_interface.py
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from datetime import datetime, timedelta
|
| 5 |
+
import re
|
| 6 |
+
import json
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
from ..agents.calendar_agent import generate_calendar_draft, create_calendar_draft
|
| 10 |
+
# Assuming is_using_local_llm might be used for warnings/status, similar to email_interface
|
| 11 |
+
# from ..utils.llm_utils import is_using_local_llm
|
| 12 |
+
|
| 13 |
+
# Agent log path for debugging (if needed)
|
| 14 |
+
log_path = "/Users/matthuang/Desktop/Deep_Agentic_AI_Tool/.cursor/debug.log"
|
| 15 |
+
|
| 16 |
+
def _create_calendar_interface():
|
| 17 |
+
"""創建 Calendar Tool 界面"""
|
| 18 |
+
gr.Markdown(
|
| 19 |
+
"""
|
| 20 |
+
### 📅 智能行事曆管理助手
|
| 21 |
+
|
| 22 |
+
使用 AI 根據您的完整提示自動生成行事曆事件草稿,您可以在創建前檢查和修改。
|
| 23 |
+
|
| 24 |
+
**使用方式:**
|
| 25 |
+
1. **快速選擇**:點擊下方常見事件按鈕,自動生成草稿
|
| 26 |
+
2. **自定義輸入**:在下方輸入完整的事件提示,包含:事件、日期、時間、地點、參與者
|
| 27 |
+
3. 查看 AI 反思評估結果和改進建議(如有)
|
| 28 |
+
4. 如果有缺失的資訊(如時間),系統會顯示下拉選單讓您選擇
|
| 29 |
+
5. 檢查並修改生成的事件內容
|
| 30 |
+
6. 確認無誤後點擊「創建事件」按鈕
|
| 31 |
+
|
| 32 |
+
**✨ 新功能:AI 迭代反思評估 + Google Maps 地點驗證**
|
| 33 |
+
- 系統會自動進行多輪反思評估(最多 3 輪)
|
| 34 |
+
- 自動驗證並標準化地址,計算交通時間
|
| 35 |
+
- 每輪評估後,如果有改進建議,會自動生成改進版本
|
| 36 |
+
- 改進後的版本會再次評估,直到 AI 認為滿意為止
|
| 37 |
+
"""
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# 快速選擇按鈕區域
|
| 41 |
+
gr.Markdown("### 🚀 快速選擇常見事件")
|
| 42 |
+
with gr.Row():
|
| 43 |
+
quick_meeting_btn = gr.Button("📋 團隊會議", variant="secondary", scale=1)
|
| 44 |
+
quick_client_btn = gr.Button("🤝 客戶拜訪", variant="secondary", scale=1)
|
| 45 |
+
quick_lunch_btn = gr.Button("🍽️ 午餐會議", variant="secondary", scale=1)
|
| 46 |
+
quick_oneonone_btn = gr.Button("💬 一對一會議", variant="secondary", scale=1)
|
| 47 |
+
with gr.Row():
|
| 48 |
+
quick_project_btn = gr.Button("📊 項目討論", variant="secondary", scale=1)
|
| 49 |
+
quick_training_btn = gr.Button("🎓 培訓/學習", variant="secondary", scale=1)
|
| 50 |
+
quick_social_btn = gr.Button("🎉 社交活動", variant="secondary", scale=1)
|
| 51 |
+
quick_custom_btn = gr.Button("✏️ 自定義輸入", variant="secondary", scale=1)
|
| 52 |
+
|
| 53 |
+
with gr.Row():
|
| 54 |
+
with gr.Column(scale=1):
|
| 55 |
+
# 單一 prompt 輸入
|
| 56 |
+
calendar_prompt_input = gr.Textbox(
|
| 57 |
+
label="📝 事件提示(包含事件、日期、時間、地點、參與者)",
|
| 58 |
+
placeholder="例如:明天下午2點團隊會議,討論項目進度,地點在會議室A,參與者包括john@example.com和mary@example.com",
|
| 59 |
+
lines=5,
|
| 60 |
+
value=""
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# 按鈕
|
| 64 |
+
with gr.Row():
|
| 65 |
+
generate_draft_btn = gr.Button("📝 生成事件草稿", variant="primary", scale=1)
|
| 66 |
+
clear_calendar_btn = gr.Button("🗑️ 清除", variant="secondary", scale=1)
|
| 67 |
+
|
| 68 |
+
# 狀態顯示
|
| 69 |
+
calendar_status_display = gr.Textbox(
|
| 70 |
+
label="📊 狀態",
|
| 71 |
+
value="等待操作...",
|
| 72 |
+
interactive=False,
|
| 73 |
+
lines=2
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# 反思結果顯示
|
| 77 |
+
calendar_reflection_display = gr.Textbox(
|
| 78 |
+
label="🔍 AI 反思評估",
|
| 79 |
+
value="等待生成事件...",
|
| 80 |
+
interactive=False,
|
| 81 |
+
lines=8,
|
| 82 |
+
visible=True
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# 缺失資訊的補充區域(動態顯示)
|
| 86 |
+
missing_info_group = gr.Group(visible=False)
|
| 87 |
+
with missing_info_group:
|
| 88 |
+
gr.Markdown("**⚠️ 請補充以下缺失的資訊:**")
|
| 89 |
+
|
| 90 |
+
# 日期選擇(如果缺失)
|
| 91 |
+
missing_date_display = gr.Dropdown(
|
| 92 |
+
label="📆 選擇日期",
|
| 93 |
+
choices=[],
|
| 94 |
+
visible=False,
|
| 95 |
+
interactive=True
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# 時間選擇(如果缺失)
|
| 99 |
+
missing_time_display = gr.Dropdown(
|
| 100 |
+
label="🕐 選擇時間",
|
| 101 |
+
choices=[],
|
| 102 |
+
visible=False,
|
| 103 |
+
interactive=True
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
fill_missing_btn = gr.Button("✅ 確認補充資訊", variant="primary", visible=False)
|
| 107 |
+
|
| 108 |
+
# 隱藏狀態變數,用於存儲 event_dict
|
| 109 |
+
event_dict_storage = gr.State(value={})
|
| 110 |
+
|
| 111 |
+
with gr.Column(scale=1):
|
| 112 |
+
# 事件詳情顯示和編輯區域
|
| 113 |
+
event_summary_display = gr.Textbox(
|
| 114 |
+
label="📌 事件標題",
|
| 115 |
+
placeholder="事件標題將在這裡顯示",
|
| 116 |
+
lines=1,
|
| 117 |
+
interactive=True
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
event_start_display = gr.Textbox(
|
| 121 |
+
label="🕐 開始時間",
|
| 122 |
+
placeholder="開始時間將在這裡顯示(格式: YYYY-MM-DDTHH:MM:SS+08:00)",
|
| 123 |
+
lines=1,
|
| 124 |
+
interactive=True
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
event_end_display = gr.Textbox(
|
| 128 |
+
label="🕐 結束時間",
|
| 129 |
+
placeholder="結束時間將在這裡顯示(格式: YYYY-MM-DDTHH:MM:SS+08:00)",
|
| 130 |
+
lines=1,
|
| 131 |
+
interactive=True
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
event_description_display = gr.Textbox(
|
| 135 |
+
label="📄 事件描述(可編輯)",
|
| 136 |
+
placeholder="事件描述將在這裡顯示,您可以編輯",
|
| 137 |
+
lines=6,
|
| 138 |
+
interactive=True
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
event_location_display = gr.Textbox(
|
| 142 |
+
label="📍 地點(可編輯,已自動驗證並標準化)",
|
| 143 |
+
placeholder="事件地點將在這裡顯示,您可以編輯",
|
| 144 |
+
lines=2,
|
| 145 |
+
interactive=True
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
event_attendees_display = gr.Textbox(
|
| 149 |
+
label="👥 參與者郵箱(可編輯,多個用逗號分隔)",
|
| 150 |
+
placeholder="參與者郵箱將在這裡顯示,您可以編輯",
|
| 151 |
+
lines=1,
|
| 152 |
+
interactive=True
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
# 創建按鈕
|
| 156 |
+
create_event_btn = gr.Button("✅ 創建事件", variant="primary", scale=1)
|
| 157 |
+
|
| 158 |
+
# 操作結果顯示
|
| 159 |
+
calendar_result_display = gr.Textbox(
|
| 160 |
+
label="📊 操作結果",
|
| 161 |
+
lines=8,
|
| 162 |
+
interactive=False
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# 生成時間選項(每30分鐘一個選項)
|
| 166 |
+
def generate_time_options():
|
| 167 |
+
"""生成時間選項列表"""
|
| 168 |
+
times = []
|
| 169 |
+
for hour in range(24):
|
| 170 |
+
for minute in [0, 30]:
|
| 171 |
+
time_str = f"{hour:02d}:{minute:02d}"
|
| 172 |
+
times.append(time_str)
|
| 173 |
+
return times
|
| 174 |
+
|
| 175 |
+
# 生成日期選項(今天、明天、後天,以及未來7天)
|
| 176 |
+
def generate_date_options():
|
| 177 |
+
"""生成日期選項列表"""
|
| 178 |
+
dates = []
|
| 179 |
+
today = datetime.now()
|
| 180 |
+
date_names = ["今天", "明天", "後天"]
|
| 181 |
+
|
| 182 |
+
for i in range(3):
|
| 183 |
+
date_obj = today + timedelta(days=i)
|
| 184 |
+
date_str = date_obj.strftime('%Y-%m-%d')
|
| 185 |
+
dates.append(f"{date_names[i]} ({date_str})")
|
| 186 |
+
|
| 187 |
+
for i in range(3, 7):
|
| 188 |
+
date_obj = today + timedelta(days=i)
|
| 189 |
+
date_str = date_obj.strftime('%Y-%m-%d')
|
| 190 |
+
dates.append(date_str)
|
| 191 |
+
|
| 192 |
+
return dates
|
| 193 |
+
|
| 194 |
+
# 快速選擇事件模板生成函數
|
| 195 |
+
def generate_quick_prompt(event_type: str) -> str:
|
| 196 |
+
"""根據事件類型生成預設提示"""
|
| 197 |
+
from datetime import datetime, timedelta
|
| 198 |
+
|
| 199 |
+
# 獲取明天的日期
|
| 200 |
+
tomorrow = datetime.now() + timedelta(days=1)
|
| 201 |
+
tomorrow_str = tomorrow.strftime("%Y-%m-%d")
|
| 202 |
+
|
| 203 |
+
templates = {
|
| 204 |
+
"meeting": f"明天下午2點團隊會議,討論項目進度和下週計劃,地點在會議室,參與者包括團隊成員",
|
| 205 |
+
"client": f"明天上午10點客戶拜訪,討論合作方案和需求,地點在客戶公司或會議室",
|
| 206 |
+
"lunch": f"明天中午12點午餐會議,與合作夥伴討論業務合作,地點在附近的餐廳",
|
| 207 |
+
"oneonone": f"明天下午3點一對一會議,討論工作進展和職業發展,地點在會議室或咖啡廳",
|
| 208 |
+
"project": f"明天上午9點項目討論會議,審查項目進度和解決問題,地點在項目室,參與者包括項目團隊",
|
| 209 |
+
"training": f"明天下午2點培訓課程,學習新技能和最佳實踐,地點在培訓室或線上",
|
| 210 |
+
"social": f"明天晚上6點團隊聚餐,慶祝項目完成,地點在餐廳,參與者包括團隊成員",
|
| 211 |
+
"custom": "" # 自定義,返回空讓用戶輸入
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
return templates.get(event_type, "")
|
| 215 |
+
|
| 216 |
+
# 快速選擇按鈕處理函數(自動生成草稿)
|
| 217 |
+
def quick_select_and_generate(event_type: str):
|
| 218 |
+
"""快速選擇事件類型並自動生成草稿"""
|
| 219 |
+
prompt = generate_quick_prompt(event_type)
|
| 220 |
+
if not prompt:
|
| 221 |
+
# 如果是自定義,只返回空提示,不自動生成
|
| 222 |
+
return (
|
| 223 |
+
prompt, # calendar_prompt_input
|
| 224 |
+
"請在下方輸入框中輸入事件提示,然後點擊「生成事件草稿」", # calendar_status_display
|
| 225 |
+
"等待輸入...", # calendar_reflection_display
|
| 226 |
+
gr.update(visible=False), # missing_info_group
|
| 227 |
+
gr.update(visible=False, choices=[]), # missing_date_display
|
| 228 |
+
gr.update(visible=False, choices=[]), # missing_time_display
|
| 229 |
+
gr.update(visible=False), # fill_missing_btn
|
| 230 |
+
"", "", "", "", "", "", # event fields
|
| 231 |
+
{},
|
| 232 |
+
"" # calendar_result_display
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# 自動生成草稿(調用 generate_draft 並返回所有輸出)
|
| 236 |
+
draft_result = generate_draft(prompt)
|
| 237 |
+
# generate_draft 返回的格式是:(status, reflection_display, missing_info_group, ...)
|
| 238 |
+
# 但我們需要返回 (prompt, status, reflection_display, ...)
|
| 239 |
+
# draft_result 是一個元組,我們需要將 prompt 添加到開頭
|
| 240 |
+
return (prompt,) + draft_result
|
| 241 |
+
|
| 242 |
+
def quick_select_meeting():
|
| 243 |
+
"""快速選擇:團隊會議"""
|
| 244 |
+
return quick_select_and_generate("meeting")
|
| 245 |
+
|
| 246 |
+
def quick_select_client():
|
| 247 |
+
"""快速選擇:客戶拜訪"""
|
| 248 |
+
return quick_select_and_generate("client")
|
| 249 |
+
|
| 250 |
+
def quick_select_lunch():
|
| 251 |
+
"""快速選擇:午餐會議"""
|
| 252 |
+
return quick_select_and_generate("lunch")
|
| 253 |
+
|
| 254 |
+
def quick_select_oneonone():
|
| 255 |
+
"""快速選擇:一對一會議"""
|
| 256 |
+
return quick_select_and_generate("oneonone")
|
| 257 |
+
|
| 258 |
+
def quick_select_project():
|
| 259 |
+
"""快速選擇:項目討論"""
|
| 260 |
+
return quick_select_and_generate("project")
|
| 261 |
+
|
| 262 |
+
def quick_select_training():
|
| 263 |
+
"""快速選擇:培訓/學習"""
|
| 264 |
+
return quick_select_and_generate("training")
|
| 265 |
+
|
| 266 |
+
def quick_select_social():
|
| 267 |
+
"""快速選擇:社交活動"""
|
| 268 |
+
return quick_select_and_generate("social")
|
| 269 |
+
|
| 270 |
+
def quick_select_custom():
|
| 271 |
+
"""快速選擇:自定義輸入(只清空,不自動生成)"""
|
| 272 |
+
return (
|
| 273 |
+
"", # calendar_prompt_input
|
| 274 |
+
"請在下方輸入框中輸入事件提示,然後點擊「生成事件草稿」", # calendar_status_display
|
| 275 |
+
"等待輸入...", # calendar_reflection_display
|
| 276 |
+
gr.update(visible=False), # missing_info_group
|
| 277 |
+
gr.update(visible=False, choices=[]), # missing_date_display
|
| 278 |
+
gr.update(visible=False, choices=[]), # missing_time_display
|
| 279 |
+
gr.update(visible=False), # fill_missing_btn
|
| 280 |
+
"", "", "", "", "", "", # event fields
|
| 281 |
+
{},
|
| 282 |
+
"" # calendar_result_display
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# 事件處理函數
|
| 286 |
+
def generate_draft(prompt):
|
| 287 |
+
"""生成行事曆事件草稿(包含反思功能)"""
|
| 288 |
+
if not prompt or not prompt.strip():
|
| 289 |
+
return (
|
| 290 |
+
"❌ 請輸入事件提示",
|
| 291 |
+
"❌ 請輸入事件提示",
|
| 292 |
+
gr.update(visible=False),
|
| 293 |
+
gr.update(visible=False, choices=[]),
|
| 294 |
+
gr.update(visible=False, choices=[]),
|
| 295 |
+
gr.update(visible=False),
|
| 296 |
+
"", "", "", "", "", "", "",
|
| 297 |
+
"❌ 請輸入事件提示"
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
try:
|
| 301 |
+
status_msg = "🔄 正在生成事件草稿..."
|
| 302 |
+
|
| 303 |
+
# 生成事件草稿(包含反思功能)
|
| 304 |
+
event_dict, status, missing_info, reflection_result, was_improved = generate_calendar_draft(
|
| 305 |
+
prompt.strip(), enable_reflection=True
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if not event_dict:
|
| 309 |
+
return (
|
| 310 |
+
status,
|
| 311 |
+
gr.update(visible=False),
|
| 312 |
+
gr.update(visible=False, choices=[]),
|
| 313 |
+
gr.update(visible=False, choices=[]),
|
| 314 |
+
gr.update(visible=False),
|
| 315 |
+
"", "", "", "", "", "", "",
|
| 316 |
+
status
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
# 格式化反思結果顯示
|
| 320 |
+
if reflection_result:
|
| 321 |
+
# 計算反思輪數
|
| 322 |
+
reflection_count = reflection_result.count("【第") if "【第" in reflection_result else 0
|
| 323 |
+
|
| 324 |
+
if was_improved:
|
| 325 |
+
if reflection_count > 1:
|
| 326 |
+
reflection_display = (
|
| 327 |
+
f"🔍 **AI 迭代反思評估結果**(共 {reflection_count} 輪)\n\n"
|
| 328 |
+
f"{reflection_result}\n\n"
|
| 329 |
+
f"✨ **已自動應用改進建議,經過 {reflection_count} 輪優化,當前顯示的是最終優化版本**"
|
| 330 |
+
)
|
| 331 |
+
else:
|
| 332 |
+
reflection_display = (
|
| 333 |
+
f"🔍 **AI 反思評估結果**\n\n"
|
| 334 |
+
f"{reflection_result}\n\n"
|
| 335 |
+
f"✨ **已自動應用改進建議,當前顯示的是優化後的版本**"
|
| 336 |
+
)
|
| 337 |
+
else:
|
| 338 |
+
reflection_display = (
|
| 339 |
+
f"🔍 **AI 反思評估結果**\n\n"
|
| 340 |
+
f"{reflection_result}\n\n"
|
| 341 |
+
f"✅ **事件質量良好,無需改進**"
|
| 342 |
+
)
|
| 343 |
+
else:
|
| 344 |
+
reflection_display = "⚠️ 反思功能未返回結果"
|
| 345 |
+
|
| 346 |
+
# 【Google Maps 整合】添加地點建議訊息
|
| 347 |
+
location_suggestion = event_dict.get("location_suggestion", "")
|
| 348 |
+
if location_suggestion:
|
| 349 |
+
# 將地點建議添加到狀態訊息中
|
| 350 |
+
if status:
|
| 351 |
+
status = f"{status}\n\n🗺️ **地點資訊:**\n{location_suggestion}"
|
| 352 |
+
else:
|
| 353 |
+
status = f"🗺️ **地點資訊:**\n{location_suggestion}"
|
| 354 |
+
|
| 355 |
+
# 檢查是否有缺失資訊
|
| 356 |
+
has_missing = bool(missing_info)
|
| 357 |
+
|
| 358 |
+
if has_missing:
|
| 359 |
+
# 顯示缺失資訊區域
|
| 360 |
+
date_visible = missing_info.get("date", False)
|
| 361 |
+
time_visible = missing_info.get("time", False)
|
| 362 |
+
|
| 363 |
+
date_choices = generate_date_options() if date_visible else []
|
| 364 |
+
time_choices = generate_time_options() if time_visible else []
|
| 365 |
+
|
| 366 |
+
return (
|
| 367 |
+
status,
|
| 368 |
+
reflection_display,
|
| 369 |
+
gr.update(visible=True), # 顯示缺失資訊區域
|
| 370 |
+
gr.update(visible=date_visible, choices=date_choices, value=date_choices[0] if date_choices else None),
|
| 371 |
+
gr.update(visible=time_visible, choices=time_choices, value=time_choices[0] if time_choices else None),
|
| 372 |
+
gr.update(visible=True), # 顯示確認按鈕
|
| 373 |
+
event_dict.get("summary", ""),
|
| 374 |
+
event_dict.get("start_datetime", ""),
|
| 375 |
+
event_dict.get("end_datetime", ""),
|
| 376 |
+
event_dict.get("description", ""),
|
| 377 |
+
event_dict.get("location", ""),
|
| 378 |
+
event_dict.get("attendees", ""),
|
| 379 |
+
event_dict, # 傳遞完整的事件字典以便後續使用
|
| 380 |
+
""
|
| 381 |
+
)
|
| 382 |
+
else:
|
| 383 |
+
# 沒有缺失資訊,直接顯示結果
|
| 384 |
+
return (
|
| 385 |
+
status,
|
| 386 |
+
reflection_display,
|
| 387 |
+
gr.update(visible=False),
|
| 388 |
+
gr.update(visible=False, choices=[]),
|
| 389 |
+
gr.update(visible=False, choices=[]),
|
| 390 |
+
gr.update(visible=False),
|
| 391 |
+
event_dict.get("summary", ""),
|
| 392 |
+
event_dict.get("start_datetime", ""),
|
| 393 |
+
event_dict.get("end_datetime", ""),
|
| 394 |
+
event_dict.get("description", ""),
|
| 395 |
+
event_dict.get("location", ""),
|
| 396 |
+
event_dict.get("attendees", ""),
|
| 397 |
+
event_dict,
|
| 398 |
+
""
|
| 399 |
+
)
|
| 400 |
+
except Exception as e:
|
| 401 |
+
error_msg = f"❌ 發生錯誤:{str(e)}"
|
| 402 |
+
print(f"Calendar Tool 錯誤:{e}")
|
| 403 |
+
import traceback
|
| 404 |
+
traceback.print_exc()
|
| 405 |
+
return (
|
| 406 |
+
"❌ 發生錯誤",
|
| 407 |
+
f"❌ 發生錯誤:{str(e)}",
|
| 408 |
+
gr.update(visible=False),
|
| 409 |
+
gr.update(visible=False, choices=[]),
|
| 410 |
+
gr.update(visible=False, choices=[]),
|
| 411 |
+
gr.update(visible=False),
|
| 412 |
+
"", "", "", "", "", "", {},
|
| 413 |
+
error_msg
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
def fill_missing_info(event_dict_storage, selected_date, selected_time):
|
| 417 |
+
"""填充缺失的資訊"""
|
| 418 |
+
if not event_dict_storage:
|
| 419 |
+
return (
|
| 420 |
+
"❌ 沒有事件資料",
|
| 421 |
+
gr.update(visible=False),
|
| 422 |
+
gr.update(visible=False, choices=[]),
|
| 423 |
+
gr.update(visible=False, choices=[]),
|
| 424 |
+
gr.update(visible=False),
|
| 425 |
+
"", "", "", "", "", "",
|
| 426 |
+
{}
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# 更新日期和時間
|
| 430 |
+
if selected_date:
|
| 431 |
+
# 從選項中提取日期字串(例如:"明天 (2026-01-25)" -> "2026-01-25")
|
| 432 |
+
if "(" in selected_date:
|
| 433 |
+
date_str = selected_date.split("(")[1].split(")")[0]
|
| 434 |
+
else:
|
| 435 |
+
date_str = selected_date
|
| 436 |
+
else:
|
| 437 |
+
date_str = event_dict_storage.get("date", "今天")
|
| 438 |
+
|
| 439 |
+
if selected_time:
|
| 440 |
+
time_str = selected_time
|
| 441 |
+
else:
|
| 442 |
+
time_str = "09:00" # 預設時間
|
| 443 |
+
|
| 444 |
+
# 重新解析日期和時間
|
| 445 |
+
from ..agents.calendar_agent import parse_datetime # Import here to avoid circular dependency or unnecessary global import
|
| 446 |
+
start_datetime, end_datetime = parse_datetime(date_str, time_str)
|
| 447 |
+
|
| 448 |
+
# 更新事件字典
|
| 449 |
+
event_dict_storage["start_datetime"] = start_datetime
|
| 450 |
+
event_dict_storage["end_datetime"] = end_datetime
|
| 451 |
+
|
| 452 |
+
return (
|
| 453 |
+
"✅ 資訊已補充,請檢查並創建事件",
|
| 454 |
+
gr.update(visible=False), # 隱藏缺失資訊區域
|
| 455 |
+
gr.update(visible=False, choices=[]),
|
| 456 |
+
gr.update(visible=False, choices=[]),
|
| 457 |
+
gr.update(visible=False),
|
| 458 |
+
event_dict_storage.get("summary", ""),
|
| 459 |
+
start_datetime,
|
| 460 |
+
end_datetime,
|
| 461 |
+
event_dict_storage.get("description", ""),
|
| 462 |
+
event_dict_storage.get("location", ""),
|
| 463 |
+
event_dict_storage.get("attendees", ""),
|
| 464 |
+
event_dict_storage
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
def create_event(summary, start_datetime, end_datetime, description, location, attendees):
|
| 468 |
+
"""創建行事曆事件"""
|
| 469 |
+
if not summary or not summary.strip():
|
| 470 |
+
return "❌ 請輸入事件標題", "❌ 請輸入事件標題"
|
| 471 |
+
|
| 472 |
+
if not start_datetime or not start_datetime.strip():
|
| 473 |
+
return "❌ 請輸入開始時間", "❌ 請輸入開始時間"
|
| 474 |
+
|
| 475 |
+
if not end_datetime or not end_datetime.strip():
|
| 476 |
+
return "❌ 請輸入結束時間", "❌ 請輸入結束時間"
|
| 477 |
+
|
| 478 |
+
try:
|
| 479 |
+
status_msg = "🔄 正在創建事件..."
|
| 480 |
+
|
| 481 |
+
# 構建事件字典
|
| 482 |
+
event_dict = {
|
| 483 |
+
"summary": summary.strip(),
|
| 484 |
+
"start_datetime": start_datetime.strip(),
|
| 485 |
+
"end_datetime": end_datetime.strip(),
|
| 486 |
+
"description": description.strip() if description else "",
|
| 487 |
+
"location": location.strip() if location else "",
|
| 488 |
+
"attendees": attendees.strip() if attendees else "",
|
| 489 |
+
"timezone": "Asia/Taipei"
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
# 創建事件
|
| 493 |
+
result = create_calendar_draft(event_dict)
|
| 494 |
+
|
| 495 |
+
return "✅ 事件已創建", result
|
| 496 |
+
except Exception as e:
|
| 497 |
+
error_msg = f"❌ 創建事件時發生錯誤:{str(e)}"
|
| 498 |
+
print(f"Calendar Tool 錯誤:{e}")
|
| 499 |
+
import traceback
|
| 500 |
+
traceback.print_exc()
|
| 501 |
+
return "❌ 發生錯誤", error_msg
|
| 502 |
+
|
| 503 |
+
def clear_calendar():
|
| 504 |
+
"""清除行事曆相關輸入和輸出"""
|
| 505 |
+
return (
|
| 506 |
+
"", # prompt
|
| 507 |
+
"等待操作...", # status
|
| 508 |
+
"等待生成事件...", # reflection_display
|
| 509 |
+
gr.update(visible=False), # missing_info_group
|
| 510 |
+
gr.update(visible=False, choices=[]), # missing_date
|
| 511 |
+
gr.update(visible=False, choices=[]), # missing_time
|
| 512 |
+
gr.update(visible=False), # fill_missing_btn
|
| 513 |
+
"", "", "", "", "", "", # event fields
|
| 514 |
+
{},
|
| 515 |
+
"" # result
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# 綁定事件
|
| 519 |
+
generate_draft_btn.click(
|
| 520 |
+
fn=generate_draft,
|
| 521 |
+
inputs=[calendar_prompt_input],
|
| 522 |
+
outputs=[
|
| 523 |
+
calendar_status_display,
|
| 524 |
+
calendar_reflection_display,
|
| 525 |
+
missing_info_group,
|
| 526 |
+
missing_date_display,
|
| 527 |
+
missing_time_display,
|
| 528 |
+
fill_missing_btn,
|
| 529 |
+
event_summary_display,
|
| 530 |
+
event_start_display,
|
| 531 |
+
event_end_display,
|
| 532 |
+
event_description_display,
|
| 533 |
+
event_location_display,
|
| 534 |
+
event_attendees_display,
|
| 535 |
+
event_dict_storage,
|
| 536 |
+
calendar_result_display
|
| 537 |
+
]
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
# 綁定快速選擇按鈕(自動填充提示並生成草稿)
|
| 541 |
+
quick_outputs = [
|
| 542 |
+
calendar_prompt_input, # 更新提示輸入框
|
| 543 |
+
calendar_status_display,
|
| 544 |
+
calendar_reflection_display,
|
| 545 |
+
missing_info_group,
|
| 546 |
+
missing_date_display,
|
| 547 |
+
missing_time_display,
|
| 548 |
+
fill_missing_btn,
|
| 549 |
+
event_summary_display,
|
| 550 |
+
event_start_display,
|
| 551 |
+
event_end_display,
|
| 552 |
+
event_description_display,
|
| 553 |
+
event_location_display,
|
| 554 |
+
event_attendees_display,
|
| 555 |
+
event_dict_storage,
|
| 556 |
+
calendar_result_display
|
| 557 |
+
]
|
| 558 |
+
|
| 559 |
+
quick_meeting_btn.click(fn=quick_select_meeting, outputs=quick_outputs)
|
| 560 |
+
quick_client_btn.click(fn=quick_select_client, outputs=quick_outputs)
|
| 561 |
+
quick_lunch_btn.click(fn=quick_select_lunch, outputs=quick_outputs)
|
| 562 |
+
quick_oneonone_btn.click(fn=quick_select_oneonone, outputs=quick_outputs)
|
| 563 |
+
quick_project_btn.click(fn=quick_select_project, outputs=quick_outputs)
|
| 564 |
+
quick_training_btn.click(fn=quick_select_training, outputs=quick_outputs)
|
| 565 |
+
quick_social_btn.click(fn=quick_select_social, outputs=quick_outputs)
|
| 566 |
+
quick_custom_btn.click(fn=quick_select_custom, outputs=quick_outputs)
|
| 567 |
+
|
| 568 |
+
fill_missing_btn.click(
|
| 569 |
+
fn=fill_missing_info,
|
| 570 |
+
inputs=[event_dict_storage, missing_date_display, missing_time_display],
|
| 571 |
+
outputs=[
|
| 572 |
+
calendar_status_display,
|
| 573 |
+
missing_info_group,
|
| 574 |
+
missing_date_display,
|
| 575 |
+
missing_time_display,
|
| 576 |
+
fill_missing_btn,
|
| 577 |
+
event_summary_display,
|
| 578 |
+
event_start_display,
|
| 579 |
+
event_end_display,
|
| 580 |
+
event_description_display,
|
| 581 |
+
event_location_display,
|
| 582 |
+
event_attendees_display,
|
| 583 |
+
event_dict_storage
|
| 584 |
+
]
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
create_event_btn.click(
|
| 588 |
+
fn=create_event,
|
| 589 |
+
inputs=[
|
| 590 |
+
event_summary_display,
|
| 591 |
+
event_start_display,
|
| 592 |
+
event_end_display,
|
| 593 |
+
event_description_display,
|
| 594 |
+
event_location_display,
|
| 595 |
+
event_attendees_display
|
| 596 |
+
],
|
| 597 |
+
outputs=[calendar_status_display, calendar_result_display]
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
clear_calendar_btn.click(
|
| 601 |
+
fn=clear_calendar,
|
| 602 |
+
outputs=[
|
| 603 |
+
calendar_prompt_input,
|
| 604 |
+
calendar_status_display,
|
| 605 |
+
calendar_reflection_display,
|
| 606 |
+
missing_info_group,
|
| 607 |
+
missing_date_display,
|
| 608 |
+
missing_time_display,
|
| 609 |
+
fill_missing_btn,
|
| 610 |
+
event_summary_display,
|
| 611 |
+
event_start_display,
|
| 612 |
+
event_end_display,
|
| 613 |
+
event_description_display,
|
| 614 |
+
event_location_display,
|
| 615 |
+
event_attendees_display,
|
| 616 |
+
event_dict_storage,
|
| 617 |
+
calendar_result_display
|
| 618 |
+
]
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
# 示例
|
| 622 |
+
gr.Examples(
|
| 623 |
+
examples=[
|
| 624 |
+
"明天下午2點團隊會議,討論項目進度,地點在會議室A,參與者包括john@example.com",
|
| 625 |
+
"2026-01-25 上午9點產品發布會,介紹新功能和改進,地點在總部大樓",
|
| 626 |
+
"後天下午3點客戶會議,討論合作細節,參與者包括客戶代表",
|
| 627 |
+
"下週一上午10點技術分享會,分享最新的 AI 技術,地點在研發中心"
|
| 628 |
+
],
|
| 629 |
+
inputs=[calendar_prompt_input]
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
# 頁腳說明
|
| 633 |
+
gr.Markdown(
|
| 634 |
+
"""
|
| 635 |
+
---
|
| 636 |
+
**注意事項:**
|
| 637 |
+
1. 使用 Google Calendar API 管理行事曆事件
|
| 638 |
+
2. 首次使用需要在專案根目錄放置 `credentials.json`(從 Google Cloud Console 下載的 OAuth2 憑證)
|
| 639 |
+
3. 首次運行時會自動開啟瀏覽器進行授權,授權後會生成 `token.json` 文件
|
| 640 |
+
4. 事件內容由 AI 自動生成,請在創建前檢查結果
|
| 641 |
+
5. 在提示中包含所有資訊:事件、日期、時間、地點、參與者
|
| 642 |
+
6. 如果缺少日期或時間,系統會顯示下拉選單讓您選擇
|
| 643 |
+
7. 日期格式支援:YYYY-MM-DD(例如:2026-01-25)或相對日期(今天、明天、後天)
|
| 644 |
+
8. 時間格式支援:24小時制(14:00)或12小時制(2:00 PM)
|
| 645 |
+
|
| 646 |
+
**設置步驟:**
|
| 647 |
+
- 前往 [Google Cloud Console](https://console.cloud.google.com/) 創建專案
|
| 648 |
+
- 啟用 Google Calendar API
|
| 649 |
+
- 創建 OAuth2 憑證並下載為 `credentials.json`
|
| 650 |
+
- 將 `credentials.json` 放在專案根目錄
|
| 651 |
+
- 確保授予 Calendar API 的完整存取權限
|
| 652 |
+
"""
|
| 653 |
+
)
|
deep_agent_rag/ui/email_interface.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# deep_agent_rag/ui/email_interface.py
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import re
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
from ..agents.email_agent import generate_email_draft, send_email_draft
|
| 9 |
+
from ..config import EMAIL_SENDER
|
| 10 |
+
from ..utils.llm_utils import is_using_local_llm # Assuming this might be used for warnings/status
|
| 11 |
+
|
| 12 |
+
# Agent log path for debugging (if needed)
|
| 13 |
+
log_path = "/Users/matthuang/Desktop/Deep_Agentic_AI_Tool/.cursor/debug.log"
|
| 14 |
+
|
| 15 |
+
def _create_email_interface():
|
| 16 |
+
"""創建 Email Tool 界面"""
|
| 17 |
+
gr.Markdown(
|
| 18 |
+
f"""
|
| 19 |
+
### 📧 智能郵件助手
|
| 20 |
+
|
| 21 |
+
使用 AI 根據您的關鍵提示自動生成專業郵件草稿,您可以在發送前檢查和修改。
|
| 22 |
+
|
| 23 |
+
**寄件者:** {EMAIL_SENDER}
|
| 24 |
+
|
| 25 |
+
**使用方式:**
|
| 26 |
+
1. 在下方輸入郵件提示(例如:"寫一封感謝信"、"邀請參加會議"等)
|
| 27 |
+
2. 輸入收件人 Gmail 郵箱地址(僅支援 @gmail.com 或 @googlemail.com)
|
| 28 |
+
3. 點擊「生成郵件草稿」按鈕
|
| 29 |
+
4. 查看 AI 反思評估結果和改進建議(如有)
|
| 30 |
+
5. 檢查並修改生成的郵件內容(特別是簽名部分)
|
| 31 |
+
6. 確認無誤後點擊「發送郵件」按鈕
|
| 32 |
+
|
| 33 |
+
**✨ 新功能:AI 迭代反思評估**
|
| 34 |
+
- 系統會自動進行多輪反思評估(最多 3 輪)
|
| 35 |
+
- 每輪評估後,如果有改進建議,會自動生成改進版本
|
| 36 |
+
- 改進後的版本會再次評估,直到 AI 認為滿意為止
|
| 37 |
+
- 您可以看到完整的反思過程和每輪的改進建議
|
| 38 |
+
|
| 39 |
+
**注意:此工具僅支援 Gmail 郵箱,收件人必須使用 Gmail 郵箱地址。**
|
| 40 |
+
"""
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
with gr.Row():
|
| 44 |
+
with gr.Column(scale=1):
|
| 45 |
+
# 郵件提示輸入
|
| 46 |
+
email_prompt_input = gr.Textbox(
|
| 47 |
+
label="📝 郵件提示",
|
| 48 |
+
placeholder="例如:寫一封感謝信,感謝對方在項目中的幫助",
|
| 49 |
+
lines=5,
|
| 50 |
+
value="寫一封專業的郵件,介紹我們的 AI 產品"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# 收件人輸入
|
| 54 |
+
recipient_input = gr.Textbox(
|
| 55 |
+
label="📮 收件人郵箱(僅支援 Gmail)",
|
| 56 |
+
placeholder="recipient@gmail.com",
|
| 57 |
+
lines=1
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# 按鈕
|
| 61 |
+
with gr.Row():
|
| 62 |
+
generate_draft_btn = gr.Button("📝 生成郵件草稿", variant="primary", scale=1)
|
| 63 |
+
clear_email_btn = gr.Button("🗑️ 清除", variant="secondary", scale=1)
|
| 64 |
+
|
| 65 |
+
# 狀態顯示
|
| 66 |
+
email_status_display = gr.Textbox(
|
| 67 |
+
label="📊 狀態",
|
| 68 |
+
value="等待操作...",
|
| 69 |
+
interactive=False,
|
| 70 |
+
lines=2
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# 反思結果顯示
|
| 74 |
+
email_reflection_display = gr.Textbox(
|
| 75 |
+
label="🔍 AI 反思評估",
|
| 76 |
+
value="等待生成郵件...",
|
| 77 |
+
interactive=False,
|
| 78 |
+
lines=8,
|
| 79 |
+
visible=True
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
with gr.Column(scale=1):
|
| 83 |
+
# 郵件主題(可編輯)
|
| 84 |
+
email_subject_input = gr.Textbox(
|
| 85 |
+
label="📌 郵件主題",
|
| 86 |
+
placeholder="郵件主題將在這裡顯示,您可以編輯",
|
| 87 |
+
lines=1,
|
| 88 |
+
interactive=True
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# 郵件正文(可編輯)
|
| 92 |
+
email_body_input = gr.Textbox(
|
| 93 |
+
label="📄 郵件正文(可編輯)",
|
| 94 |
+
placeholder="郵件內容將在這裡顯示,您可以編輯",
|
| 95 |
+
lines=15,
|
| 96 |
+
interactive=True
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# 發送按鈕
|
| 100 |
+
send_draft_btn = gr.Button("📧 發送郵件", variant="primary", scale=1)
|
| 101 |
+
|
| 102 |
+
# 發送結果顯示
|
| 103 |
+
email_result_display = gr.Textbox(
|
| 104 |
+
label="📊 發送結果",
|
| 105 |
+
lines=5,
|
| 106 |
+
interactive=False
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# 事件處理函數
|
| 110 |
+
def generate_draft(prompt, recipient):
|
| 111 |
+
"""生成郵件草稿(包含反思功能)"""
|
| 112 |
+
if not prompt or not prompt.strip():
|
| 113 |
+
return "❌ 請輸入郵件提示", "", "", "❌ 請輸入郵件提示", "❌ 請輸入郵件提示"
|
| 114 |
+
|
| 115 |
+
if not recipient or not recipient.strip():
|
| 116 |
+
return "❌ 請輸入收件人郵箱", "", "", "❌ 請輸入收件人郵箱", "❌ 請輸入收件人郵箱"
|
| 117 |
+
|
| 118 |
+
# 驗證郵箱格式和 Gmail 限制
|
| 119 |
+
if "@" not in recipient or "." not in recipient.split("@")[1]:
|
| 120 |
+
return "❌ 郵箱格式不正確", "", "", "❌ 郵箱格式不正確,請輸入有效的郵箱地址", "❌ 郵箱格式不正確,請輸入有效的郵箱地址"
|
| 121 |
+
|
| 122 |
+
# 驗證是否為 Gmail 郵箱
|
| 123 |
+
recipient_lower = recipient.strip().lower()
|
| 124 |
+
if not (recipient_lower.endswith("@gmail.com") or recipient_lower.endswith("@googlemail.com")):
|
| 125 |
+
return "❌ 僅支援 Gmail 郵箱", "", "", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址"
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
status_msg = "🔄 正在生成郵件草稿..."
|
| 129 |
+
reflection_msg = "🔄 正在生成郵件草稿..."
|
| 130 |
+
|
| 131 |
+
# 生成郵件草稿(包含反思功能,會自動改進)
|
| 132 |
+
subject, body, status, reflection_result, was_improved = generate_email_draft(
|
| 133 |
+
prompt, recipient.strip(), enable_reflection=True
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
if subject and body:
|
| 137 |
+
# 格式化反思結果顯示
|
| 138 |
+
if reflection_result:
|
| 139 |
+
# 計算反思輪數
|
| 140 |
+
reflection_count = reflection_result.count("【第") if "【第" in reflection_result else 0
|
| 141 |
+
|
| 142 |
+
if was_improved:
|
| 143 |
+
if reflection_count > 1:
|
| 144 |
+
reflection_display = (
|
| 145 |
+
f"🔍 **AI 迭代反思評估結果**(共 {reflection_count} 輪)\n\n"
|
| 146 |
+
f"{reflection_result}\n\n"
|
| 147 |
+
f"✨ **已自動應用改進建議,經過 {reflection_count} 輪優化,當前顯示的是最終優化版本**"
|
| 148 |
+
)
|
| 149 |
+
else:
|
| 150 |
+
reflection_display = (
|
| 151 |
+
f"🔍 **AI 反思評估結果**\n\n"
|
| 152 |
+
f"{reflection_result}\n\n"
|
| 153 |
+
f"✨ **已自動應用改進建議,當前顯示的是優化後的版本**"
|
| 154 |
+
)
|
| 155 |
+
else:
|
| 156 |
+
reflection_display = (
|
| 157 |
+
f"🔍 **AI 反思評估結果**\n\n"
|
| 158 |
+
f"{reflection_result}\n\n"
|
| 159 |
+
f"✅ **郵件質量良好,無需改進**"
|
| 160 |
+
)
|
| 161 |
+
else:
|
| 162 |
+
reflection_display = "⚠️ 反思功能未返回結果"
|
| 163 |
+
|
| 164 |
+
return status, subject, body, "", reflection_display
|
| 165 |
+
else:
|
| 166 |
+
return status, "", "", status, "❌ 生成失敗,無法進行反思評估"
|
| 167 |
+
except Exception as e:
|
| 168 |
+
error_msg = f"❌ 發生錯誤:{str(e)}"
|
| 169 |
+
print(f"Email Tool 錯誤:{e}")
|
| 170 |
+
import traceback
|
| 171 |
+
traceback.print_exc()
|
| 172 |
+
return "❌ 發生錯誤", "", "", error_msg, f"❌ 發生錯誤:{str(e)}"
|
| 173 |
+
|
| 174 |
+
def send_draft(recipient, subject, body):
|
| 175 |
+
"""發送已編輯的郵件草稿"""
|
| 176 |
+
if not recipient or not recipient.strip():
|
| 177 |
+
return "❌ 請輸入收件人郵箱", "❌ 請輸入收件人郵箱"
|
| 178 |
+
|
| 179 |
+
if not subject or not subject.strip():
|
| 180 |
+
return "❌ 請輸入郵件主題", "❌ 請輸入郵件主題"
|
| 181 |
+
|
| 182 |
+
if not body or not body.strip():
|
| 183 |
+
return "❌ 請輸入郵件內容", "❌ 請輸入郵件內容"
|
| 184 |
+
|
| 185 |
+
# 驗證郵箱格式和 Gmail 限制
|
| 186 |
+
if "@" not in recipient or "." not in recipient.split("@")[1]:
|
| 187 |
+
return "❌ 郵箱格式不正確", "❌ 郵箱格式不正確,請輸入有效的郵箱地址"
|
| 188 |
+
|
| 189 |
+
# 驗證是否為 Gmail 郵箱
|
| 190 |
+
recipient_lower = recipient.strip().lower()
|
| 191 |
+
if not (recipient_lower.endswith("@gmail.com") or recipient_lower.endswith("@googlemail.com")):
|
| 192 |
+
return "❌ 僅支援 Gmail 郵箱", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址"
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
status_msg = "🔄 正在發送郵件..."
|
| 196 |
+
|
| 197 |
+
# 發送郵件
|
| 198 |
+
result = send_email_draft(recipient.strip(), subject.strip(), body.strip())
|
| 199 |
+
|
| 200 |
+
return "✅ 郵件已發送", result
|
| 201 |
+
except Exception as e:
|
| 202 |
+
error_msg = f"❌ 發送郵件時發生錯誤:{str(e)}"
|
| 203 |
+
print(f"Email Tool 錯誤:{e}")
|
| 204 |
+
import traceback
|
| 205 |
+
traceback.print_exc()
|
| 206 |
+
return "❌ 發生錯誤", error_msg
|
| 207 |
+
|
| 208 |
+
def clear_email():
|
| 209 |
+
"""清除郵件相關輸入和輸出"""
|
| 210 |
+
return "", "", "等待操作...", "", "", "等待生成郵件..."
|
| 211 |
+
|
| 212 |
+
# 綁定事件
|
| 213 |
+
generate_draft_btn.click(
|
| 214 |
+
fn=generate_draft,
|
| 215 |
+
inputs=[email_prompt_input, recipient_input],
|
| 216 |
+
outputs=[email_status_display, email_subject_input, email_body_input, email_result_display, email_reflection_display]
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
send_draft_btn.click(
|
| 220 |
+
fn=send_draft,
|
| 221 |
+
inputs=[recipient_input, email_subject_input, email_body_input],
|
| 222 |
+
outputs=[email_status_display, email_result_display]
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
clear_email_btn.click(
|
| 226 |
+
fn=clear_email,
|
| 227 |
+
outputs=[email_prompt_input, recipient_input, email_status_display, email_subject_input, email_body_input, email_result_display, email_reflection_display]
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# 示例
|
| 231 |
+
gr.Examples(
|
| 232 |
+
examples=[
|
| 233 |
+
["寫一封感謝信,感謝對方在項目中的幫助和支持", "example@gmail.com"],
|
| 234 |
+
["邀請參加下週的產品發布會", "colleague@gmail.com"],
|
| 235 |
+
["詢問項目進度並提供更新", "partner@gmail.com"],
|
| 236 |
+
["發送會議記錄和後續行動項目", "team@gmail.com"]
|
| 237 |
+
],
|
| 238 |
+
inputs=[email_prompt_input, recipient_input]
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# 頁腳說明
|
| 242 |
+
gr.Markdown(
|
| 243 |
+
f"""
|
| 244 |
+
---
|
| 245 |
+
**注意事項:**
|
| 246 |
+
1. 使用 Gmail API 發送郵件,避免被歸類為垃圾郵件
|
| 247 |
+
2. **此工具僅支援 Gmail 郵箱,收件人必須使用 @gmail.com 或 @googlemail.com 結尾的郵箱地址**
|
| 248 |
+
3. 首次使用需要在專案根目錄放置 `credentials.json`(從 Google Cloud Console 下載的 OAuth2 憑證)
|
| 249 |
+
4. 首次運行時會自動開啟瀏覽器進行授權,授權後會生成 `token.json` 文件
|
| 250 |
+
5. 郵件內容由 AI 自動生成,請在發送前檢查結果
|
| 251 |
+
6. 寄件者固定為:{EMAIL_SENDER}
|
| 252 |
+
|
| 253 |
+
**設置步驟:**
|
| 254 |
+
- 前往 [Google Cloud Console](https://console.cloud.google.com/) 創建專案
|
| 255 |
+
- 啟用 Gmail API
|
| 256 |
+
- 創建 OAuth2 憑證並下載為 `credentials.json`
|
| 257 |
+
- 將 `credentials.json` 放在專案根目錄
|
| 258 |
+
"""
|
| 259 |
+
)
|
deep_agent_rag/ui/gradio_interface.py
CHANGED
|
@@ -5,12 +5,17 @@ Gradio 界面模組
|
|
| 5 |
import uuid
|
| 6 |
import re
|
| 7 |
import time
|
|
|
|
|
|
|
| 8 |
from typing import Iterator, Tuple
|
| 9 |
import gradio as gr
|
| 10 |
from langchain_core.messages import HumanMessage
|
| 11 |
|
| 12 |
# graph 和 rag_retriever 將從外部傳入,不在這裡導入
|
| 13 |
from ..utils.llm_utils import get_llm_type, is_using_local_llm
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def run_research_agent(query: str, graph, thread_id: str = None) -> Iterator[Tuple[str, str, str, str, str]]:
|
|
@@ -204,9 +209,9 @@ def create_gradio_interface(graph):
|
|
| 204 |
gr.Markdown(
|
| 205 |
"""
|
| 206 |
<div class="header">
|
| 207 |
-
<h1>🚀 Deep Research Agent with RAG
|
| 208 |
<p><strong>功能特色:</strong></p>
|
| 209 |
-
<p>📊 股票資訊查詢 | 🌐 網路搜尋 | 📚 PDF 知識庫查詢(Tree of Thoughts 論文)| 📧 智能郵件助手 | 📅 智能行事曆管理</p>
|
| 210 |
<p><strong>智能規劃:</strong> 系統會根據問題類型自動選擇合適的研究工具</p>
|
| 211 |
<p><strong>本地模型:</strong> 使用 MLX 本地模型,保護隱私,無需 API 金鑰</p>
|
| 212 |
</div>
|
|
@@ -227,6 +232,10 @@ def create_gradio_interface(graph):
|
|
| 227 |
# Tab 3: Calendar Tool
|
| 228 |
with gr.Tab("📅 Calendar Tool"):
|
| 229 |
_create_calendar_interface()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
return demo
|
| 232 |
|
|
@@ -363,895 +372,8 @@ def _create_research_interface(graph):
|
|
| 363 |
)
|
| 364 |
|
| 365 |
|
| 366 |
-
def _create_email_interface():
|
| 367 |
-
"""創建 Email Tool 界面"""
|
| 368 |
-
from ..agents.email_agent import generate_email_draft, send_email_draft
|
| 369 |
-
from ..config import EMAIL_SENDER
|
| 370 |
-
|
| 371 |
-
gr.Markdown(
|
| 372 |
-
f"""
|
| 373 |
-
### 📧 智能郵件助手
|
| 374 |
-
|
| 375 |
-
使用 AI 根據您的關鍵提示自動生成專業郵件草稿,您可以在發送前檢查和修改。
|
| 376 |
-
|
| 377 |
-
**寄件者:** {EMAIL_SENDER}
|
| 378 |
-
|
| 379 |
-
**使用方式:**
|
| 380 |
-
1. 在下方輸入郵件提示(例如:"寫一封感謝信"、"邀請參加會議"等)
|
| 381 |
-
2. 輸入收件人 Gmail 郵箱地址(僅支援 @gmail.com 或 @googlemail.com)
|
| 382 |
-
3. 點擊「生成郵件草稿」按鈕
|
| 383 |
-
4. 查看 AI 反思評估結果和改進建議(如有)
|
| 384 |
-
5. 檢查並修改生成的郵件內容(特別是簽名部分)
|
| 385 |
-
6. 確認無誤後點擊「發送郵件」按鈕
|
| 386 |
-
|
| 387 |
-
**✨ 新功能:AI 迭代反思評估**
|
| 388 |
-
- 系統會自動進行多輪反思評估(最多 3 輪)
|
| 389 |
-
- 每輪評估後,如果有改進建議,會自動生成改進版本
|
| 390 |
-
- 改進後的版本會再次評估,直到 AI 認為滿意為止
|
| 391 |
-
- 您可以看到完整的反思過程和每輪的改進建議
|
| 392 |
-
|
| 393 |
-
**注意:此工具僅支援 Gmail 郵箱,收件人必須使用 Gmail 郵箱地址。**
|
| 394 |
-
"""
|
| 395 |
-
)
|
| 396 |
-
|
| 397 |
-
with gr.Row():
|
| 398 |
-
with gr.Column(scale=1):
|
| 399 |
-
# 郵件提示輸入
|
| 400 |
-
email_prompt_input = gr.Textbox(
|
| 401 |
-
label="📝 郵件提示",
|
| 402 |
-
placeholder="例如:寫一封感謝信,感謝對方在項目中的幫助",
|
| 403 |
-
lines=5,
|
| 404 |
-
value="寫一封專業的郵件,介紹我們的 AI 產品"
|
| 405 |
-
)
|
| 406 |
-
|
| 407 |
-
# 收件人輸入
|
| 408 |
-
recipient_input = gr.Textbox(
|
| 409 |
-
label="📮 收件人郵箱(僅支援 Gmail)",
|
| 410 |
-
placeholder="recipient@gmail.com",
|
| 411 |
-
lines=1
|
| 412 |
-
)
|
| 413 |
-
|
| 414 |
-
# 按鈕
|
| 415 |
-
with gr.Row():
|
| 416 |
-
generate_draft_btn = gr.Button("📝 生成郵件草稿", variant="primary", scale=1)
|
| 417 |
-
clear_email_btn = gr.Button("🗑️ 清除", variant="secondary", scale=1)
|
| 418 |
-
|
| 419 |
-
# 狀態顯示
|
| 420 |
-
email_status_display = gr.Textbox(
|
| 421 |
-
label="📊 狀態",
|
| 422 |
-
value="等待操作...",
|
| 423 |
-
interactive=False,
|
| 424 |
-
lines=2
|
| 425 |
-
)
|
| 426 |
-
|
| 427 |
-
# 反思結果顯示
|
| 428 |
-
email_reflection_display = gr.Textbox(
|
| 429 |
-
label="🔍 AI 反思評估",
|
| 430 |
-
value="等待生成郵件...",
|
| 431 |
-
interactive=False,
|
| 432 |
-
lines=8,
|
| 433 |
-
visible=True
|
| 434 |
-
)
|
| 435 |
-
|
| 436 |
-
with gr.Column(scale=1):
|
| 437 |
-
# 郵件主題(可編輯)
|
| 438 |
-
email_subject_input = gr.Textbox(
|
| 439 |
-
label="📌 郵件主題",
|
| 440 |
-
placeholder="郵件主題將在這裡顯示���您可以編輯",
|
| 441 |
-
lines=1,
|
| 442 |
-
interactive=True
|
| 443 |
-
)
|
| 444 |
-
|
| 445 |
-
# 郵件正文(可編輯)
|
| 446 |
-
email_body_input = gr.Textbox(
|
| 447 |
-
label="📄 郵件正文(可編輯)",
|
| 448 |
-
placeholder="郵件內容將在這裡顯示,您可以編輯",
|
| 449 |
-
lines=15,
|
| 450 |
-
interactive=True
|
| 451 |
-
)
|
| 452 |
-
|
| 453 |
-
# 發送按鈕
|
| 454 |
-
send_draft_btn = gr.Button("📧 發送郵件", variant="primary", scale=1)
|
| 455 |
-
|
| 456 |
-
# 發送結果顯示
|
| 457 |
-
email_result_display = gr.Textbox(
|
| 458 |
-
label="📊 發送結果",
|
| 459 |
-
lines=5,
|
| 460 |
-
interactive=False
|
| 461 |
-
)
|
| 462 |
-
|
| 463 |
-
# 事件處理函數
|
| 464 |
-
def generate_draft(prompt, recipient):
|
| 465 |
-
"""生成郵件草稿(包含反思功能)"""
|
| 466 |
-
if not prompt or not prompt.strip():
|
| 467 |
-
return "❌ 請輸入郵件提示", "", "", "❌ 請輸入郵件提示", "❌ 請輸入郵件提示"
|
| 468 |
-
|
| 469 |
-
if not recipient or not recipient.strip():
|
| 470 |
-
return "❌ 請輸入收件人郵箱", "", "", "❌ 請輸入收件人郵箱", "❌ 請輸入收件人郵箱"
|
| 471 |
-
|
| 472 |
-
# 驗證郵箱格式和 Gmail 限制
|
| 473 |
-
if "@" not in recipient or "." not in recipient.split("@")[1]:
|
| 474 |
-
return "❌ 郵箱格式不正確", "", "", "❌ 郵箱格式不正確,請輸入有效的郵箱地址", "❌ 郵箱格式不正確,請輸入有效的郵箱地址"
|
| 475 |
-
|
| 476 |
-
# 驗證是否為 Gmail 郵箱
|
| 477 |
-
recipient_lower = recipient.strip().lower()
|
| 478 |
-
if not (recipient_lower.endswith("@gmail.com") or recipient_lower.endswith("@googlemail.com")):
|
| 479 |
-
return "❌ 僅支援 Gmail 郵箱", "", "", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址"
|
| 480 |
-
|
| 481 |
-
try:
|
| 482 |
-
status_msg = "🔄 正在生成郵件草稿..."
|
| 483 |
-
reflection_msg = "🔄 正在生成郵件草稿..."
|
| 484 |
-
|
| 485 |
-
# 生成郵件草稿(包含反思功能,會自動改進)
|
| 486 |
-
subject, body, status, reflection_result, was_improved = generate_email_draft(
|
| 487 |
-
prompt, recipient.strip(), enable_reflection=True
|
| 488 |
-
)
|
| 489 |
-
|
| 490 |
-
if subject and body:
|
| 491 |
-
# 格式化反思結果顯示
|
| 492 |
-
if reflection_result:
|
| 493 |
-
# 計算反思輪數
|
| 494 |
-
reflection_count = reflection_result.count("【第") if "【第" in reflection_result else 0
|
| 495 |
-
|
| 496 |
-
if was_improved:
|
| 497 |
-
if reflection_count > 1:
|
| 498 |
-
reflection_display = (
|
| 499 |
-
f"🔍 **AI 迭代反思評估結果**(共 {reflection_count} 輪)\n\n"
|
| 500 |
-
f"{reflection_result}\n\n"
|
| 501 |
-
f"✨ **已自動應用改進建議,經過 {reflection_count} 輪優化,當前顯示的是最終優化版本**"
|
| 502 |
-
)
|
| 503 |
-
else:
|
| 504 |
-
reflection_display = (
|
| 505 |
-
f"🔍 **AI 反思評估結果**\n\n"
|
| 506 |
-
f"{reflection_result}\n\n"
|
| 507 |
-
f"✨ **已自動應用改進建議,當前顯示的是優化後的版本**"
|
| 508 |
-
)
|
| 509 |
-
else:
|
| 510 |
-
reflection_display = (
|
| 511 |
-
f"🔍 **AI 反思評估結果**\n\n"
|
| 512 |
-
f"{reflection_result}\n\n"
|
| 513 |
-
f"✅ **郵件質量良好,無需改進**"
|
| 514 |
-
)
|
| 515 |
-
else:
|
| 516 |
-
reflection_display = "⚠️ 反思功能未返回結果"
|
| 517 |
-
|
| 518 |
-
return status, subject, body, "", reflection_display
|
| 519 |
-
else:
|
| 520 |
-
return status, "", "", status, "❌ 生成失敗,無法進行反思評估"
|
| 521 |
-
except Exception as e:
|
| 522 |
-
error_msg = f"❌ 發生錯誤:{str(e)}"
|
| 523 |
-
print(f"Email Tool 錯誤:{e}")
|
| 524 |
-
import traceback
|
| 525 |
-
traceback.print_exc()
|
| 526 |
-
return "❌ 發生錯誤", "", "", error_msg, f"❌ 發生錯誤:{str(e)}"
|
| 527 |
-
|
| 528 |
-
def send_draft(recipient, subject, body):
|
| 529 |
-
"""發送已編輯的郵件草稿"""
|
| 530 |
-
if not recipient or not recipient.strip():
|
| 531 |
-
return "❌ 請輸入收件人郵箱", "❌ 請輸入收件人郵箱"
|
| 532 |
-
|
| 533 |
-
if not subject or not subject.strip():
|
| 534 |
-
return "❌ 請輸入郵件主題", "❌ 請輸入郵件主題"
|
| 535 |
-
|
| 536 |
-
if not body or not body.strip():
|
| 537 |
-
return "❌ 請輸入郵件內容", "❌ 請輸入郵件內容"
|
| 538 |
-
|
| 539 |
-
# 驗證郵箱格式和 Gmail 限制
|
| 540 |
-
if "@" not in recipient or "." not in recipient.split("@")[1]:
|
| 541 |
-
return "❌ 郵箱格式不正確", "❌ 郵箱格式不正確,請輸入有效的郵箱地址"
|
| 542 |
-
|
| 543 |
-
# 驗證是否為 Gmail 郵箱
|
| 544 |
-
recipient_lower = recipient.strip().lower()
|
| 545 |
-
if not (recipient_lower.endswith("@gmail.com") or recipient_lower.endswith("@googlemail.com")):
|
| 546 |
-
return "❌ 僅支援 Gmail 郵箱", "❌ 此工具僅支援 Gmail 郵箱(@gmail.com 或 @googlemail.com),請輸入 Gmail 郵箱地址"
|
| 547 |
-
|
| 548 |
-
try:
|
| 549 |
-
status_msg = "🔄 正在發送郵件..."
|
| 550 |
-
|
| 551 |
-
# 發送郵件
|
| 552 |
-
result = send_email_draft(recipient.strip(), subject.strip(), body.strip())
|
| 553 |
-
|
| 554 |
-
return "✅ 郵件已發送", result
|
| 555 |
-
except Exception as e:
|
| 556 |
-
error_msg = f"❌ 發送郵件時發生錯誤:{str(e)}"
|
| 557 |
-
print(f"Email Tool 錯誤:{e}")
|
| 558 |
-
import traceback
|
| 559 |
-
traceback.print_exc()
|
| 560 |
-
return "❌ 發生錯誤", error_msg
|
| 561 |
-
|
| 562 |
-
def clear_email():
|
| 563 |
-
"""清除郵件相關輸入和輸出"""
|
| 564 |
-
return "", "", "等待操作...", "", "", "等待生成郵件..."
|
| 565 |
-
|
| 566 |
-
# 綁定事件
|
| 567 |
-
generate_draft_btn.click(
|
| 568 |
-
fn=generate_draft,
|
| 569 |
-
inputs=[email_prompt_input, recipient_input],
|
| 570 |
-
outputs=[email_status_display, email_subject_input, email_body_input, email_result_display, email_reflection_display]
|
| 571 |
-
)
|
| 572 |
-
|
| 573 |
-
send_draft_btn.click(
|
| 574 |
-
fn=send_draft,
|
| 575 |
-
inputs=[recipient_input, email_subject_input, email_body_input],
|
| 576 |
-
outputs=[email_status_display, email_result_display]
|
| 577 |
-
)
|
| 578 |
-
|
| 579 |
-
clear_email_btn.click(
|
| 580 |
-
fn=clear_email,
|
| 581 |
-
outputs=[email_prompt_input, recipient_input, email_status_display, email_subject_input, email_body_input, email_result_display, email_reflection_display]
|
| 582 |
-
)
|
| 583 |
-
|
| 584 |
-
# 示例
|
| 585 |
-
gr.Examples(
|
| 586 |
-
examples=[
|
| 587 |
-
["寫一封感謝信,感謝對方在項目中的幫助和支持", "example@gmail.com"],
|
| 588 |
-
["邀請參加下週的產品發布會", "colleague@gmail.com"],
|
| 589 |
-
["詢問項目進度並提供更新", "partner@gmail.com"],
|
| 590 |
-
["發送會議記錄和後續行動項目", "team@gmail.com"]
|
| 591 |
-
],
|
| 592 |
-
inputs=[email_prompt_input, recipient_input]
|
| 593 |
-
)
|
| 594 |
-
|
| 595 |
-
# 頁腳說明
|
| 596 |
-
gr.Markdown(
|
| 597 |
-
f"""
|
| 598 |
-
---
|
| 599 |
-
**注意事項:**
|
| 600 |
-
1. 使用 Gmail API 發送郵件,避免被歸類為垃圾郵件
|
| 601 |
-
2. **此工具僅支援 Gmail 郵箱,收件人必須使用 @gmail.com 或 @googlemail.com 結尾的郵箱地址**
|
| 602 |
-
3. 首次使用需要在專案根目錄放置 `credentials.json`(從 Google Cloud Console 下載的 OAuth2 憑證)
|
| 603 |
-
4. 首次運行時會自動開啟瀏覽器進行授權,授權後會生成 `token.json` 文件
|
| 604 |
-
5. 郵件內容由 AI 自動生成,請在發送前檢查結果
|
| 605 |
-
6. 寄件者固定為:{EMAIL_SENDER}
|
| 606 |
-
|
| 607 |
-
**設置步驟:**
|
| 608 |
-
- 前往 [Google Cloud Console](https://console.cloud.google.com/) 創建專案
|
| 609 |
-
- 啟用 Gmail API
|
| 610 |
-
- 創建 OAuth2 憑證並下載為 `credentials.json`
|
| 611 |
-
- 將 `credentials.json` 放在專案根目錄
|
| 612 |
-
"""
|
| 613 |
-
)
|
| 614 |
|
| 615 |
|
| 616 |
-
def _create_calendar_interface():
|
| 617 |
-
"""創建 Calendar Tool 界面"""
|
| 618 |
-
from ..agents.calendar_agent import generate_calendar_draft, create_calendar_draft
|
| 619 |
-
from datetime import datetime, timedelta
|
| 620 |
-
|
| 621 |
-
gr.Markdown(
|
| 622 |
-
"""
|
| 623 |
-
### 📅 智能行事曆管理助手
|
| 624 |
-
|
| 625 |
-
使用 AI 根據您的完整提示自動生成行事曆事件草稿,您可以在創建前檢查和修改。
|
| 626 |
-
|
| 627 |
-
**使用方式:**
|
| 628 |
-
1. **快速選擇**:點擊下方常見事件按鈕,自動生成草稿
|
| 629 |
-
2. **自定義輸入**:在下方輸入完整的事件提示,包含:事件、日期、時間、地點、參與者
|
| 630 |
-
3. 查看 AI 反思評估結果和改進建議(如有)
|
| 631 |
-
4. 如果有缺失的資訊(如時間),系統會顯示下拉選單讓您選擇
|
| 632 |
-
5. 檢查並修改生成的事件內容
|
| 633 |
-
6. 確認無誤後點擊「創建事件」按鈕
|
| 634 |
-
|
| 635 |
-
**✨ 新功能:AI 迭代反思評估 + Google Maps 地點驗證**
|
| 636 |
-
- 系統會自動進行多輪反思評估(最多 3 輪)
|
| 637 |
-
- 自動驗證並標準化地址,計算交通時間
|
| 638 |
-
- 每輪評估後,如果有改進建議,會自動生成改進版本
|
| 639 |
-
- 改進後的版本會再次評估,直到 AI 認為滿意為止
|
| 640 |
-
"""
|
| 641 |
-
)
|
| 642 |
-
|
| 643 |
-
# 快速選擇按鈕區域
|
| 644 |
-
gr.Markdown("### 🚀 快速選擇常見事件")
|
| 645 |
-
with gr.Row():
|
| 646 |
-
quick_meeting_btn = gr.Button("📋 團隊會議", variant="secondary", scale=1)
|
| 647 |
-
quick_client_btn = gr.Button("🤝 客戶拜訪", variant="secondary", scale=1)
|
| 648 |
-
quick_lunch_btn = gr.Button("🍽️ 午餐會議", variant="secondary", scale=1)
|
| 649 |
-
quick_oneonone_btn = gr.Button("💬 一對一會議", variant="secondary", scale=1)
|
| 650 |
-
with gr.Row():
|
| 651 |
-
quick_project_btn = gr.Button("📊 項目討論", variant="secondary", scale=1)
|
| 652 |
-
quick_training_btn = gr.Button("🎓 培訓/學習", variant="secondary", scale=1)
|
| 653 |
-
quick_social_btn = gr.Button("🎉 社交活動", variant="secondary", scale=1)
|
| 654 |
-
quick_custom_btn = gr.Button("✏️ 自定義輸入", variant="secondary", scale=1)
|
| 655 |
-
|
| 656 |
-
with gr.Row():
|
| 657 |
-
with gr.Column(scale=1):
|
| 658 |
-
# 單一 prompt 輸入
|
| 659 |
-
calendar_prompt_input = gr.Textbox(
|
| 660 |
-
label="📝 事件提示(包含事件、日期、時間、地點、參與者)",
|
| 661 |
-
placeholder="例如:明天下午2點團隊會議,討論項目進度,地點在會議室A,參與者包括john@example.com和mary@example.com",
|
| 662 |
-
lines=5,
|
| 663 |
-
value=""
|
| 664 |
-
)
|
| 665 |
-
|
| 666 |
-
# 按鈕
|
| 667 |
-
with gr.Row():
|
| 668 |
-
generate_draft_btn = gr.Button("📝 生成事件草稿", variant="primary", scale=1)
|
| 669 |
-
clear_calendar_btn = gr.Button("🗑️ 清除", variant="secondary", scale=1)
|
| 670 |
-
|
| 671 |
-
# 狀態顯示
|
| 672 |
-
calendar_status_display = gr.Textbox(
|
| 673 |
-
label="📊 狀態",
|
| 674 |
-
value="等待操作...",
|
| 675 |
-
interactive=False,
|
| 676 |
-
lines=2
|
| 677 |
-
)
|
| 678 |
-
|
| 679 |
-
# 反思結果顯示
|
| 680 |
-
calendar_reflection_display = gr.Textbox(
|
| 681 |
-
label="🔍 AI 反思評估",
|
| 682 |
-
value="等待生成事件...",
|
| 683 |
-
interactive=False,
|
| 684 |
-
lines=8,
|
| 685 |
-
visible=True
|
| 686 |
-
)
|
| 687 |
-
|
| 688 |
-
# 缺失資訊的補充區域(動態顯示)
|
| 689 |
-
missing_info_group = gr.Group(visible=False)
|
| 690 |
-
with missing_info_group:
|
| 691 |
-
gr.Markdown("**⚠️ 請補充以下缺失的資訊:**")
|
| 692 |
-
|
| 693 |
-
# 日期選擇(如果缺失)
|
| 694 |
-
missing_date_display = gr.Dropdown(
|
| 695 |
-
label="📆 選擇日期",
|
| 696 |
-
choices=[],
|
| 697 |
-
visible=False,
|
| 698 |
-
interactive=True
|
| 699 |
-
)
|
| 700 |
-
|
| 701 |
-
# 時間選擇(如果缺失)
|
| 702 |
-
missing_time_display = gr.Dropdown(
|
| 703 |
-
label="🕐 選擇時間",
|
| 704 |
-
choices=[],
|
| 705 |
-
visible=False,
|
| 706 |
-
interactive=True
|
| 707 |
-
)
|
| 708 |
-
|
| 709 |
-
fill_missing_btn = gr.Button("✅ 確認補充資訊", variant="primary", visible=False)
|
| 710 |
-
|
| 711 |
-
# 隱藏狀態變數,用於存儲 event_dict
|
| 712 |
-
event_dict_storage = gr.State(value={})
|
| 713 |
-
|
| 714 |
-
with gr.Column(scale=1):
|
| 715 |
-
# 事件詳情顯示和編輯區域
|
| 716 |
-
event_summary_display = gr.Textbox(
|
| 717 |
-
label="📌 事件標題",
|
| 718 |
-
placeholder="事件標題將在這裡顯示",
|
| 719 |
-
lines=1,
|
| 720 |
-
interactive=True
|
| 721 |
-
)
|
| 722 |
-
|
| 723 |
-
event_start_display = gr.Textbox(
|
| 724 |
-
label="🕐 開始時間",
|
| 725 |
-
placeholder="開始時間將在這裡顯示(格式: YYYY-MM-DDTHH:MM:SS+08:00)",
|
| 726 |
-
lines=1,
|
| 727 |
-
interactive=True
|
| 728 |
-
)
|
| 729 |
-
|
| 730 |
-
event_end_display = gr.Textbox(
|
| 731 |
-
label="🕐 結束時間",
|
| 732 |
-
placeholder="結束時間將在這裡顯示(格式: YYYY-MM-DDTHH:MM:SS+08:00)",
|
| 733 |
-
lines=1,
|
| 734 |
-
interactive=True
|
| 735 |
-
)
|
| 736 |
-
|
| 737 |
-
event_description_display = gr.Textbox(
|
| 738 |
-
label="📄 事件描述(可編輯)",
|
| 739 |
-
placeholder="事件描述將在這裡顯示,您可以編輯",
|
| 740 |
-
lines=6,
|
| 741 |
-
interactive=True
|
| 742 |
-
)
|
| 743 |
-
|
| 744 |
-
event_location_display = gr.Textbox(
|
| 745 |
-
label="📍 地點(可編輯,已自動驗證並標準化)",
|
| 746 |
-
placeholder="事件地點將在這裡顯示,您可以編輯",
|
| 747 |
-
lines=2,
|
| 748 |
-
interactive=True
|
| 749 |
-
)
|
| 750 |
-
|
| 751 |
-
event_attendees_display = gr.Textbox(
|
| 752 |
-
label="👥 參與者郵箱(可編輯,多個用逗號分隔)",
|
| 753 |
-
placeholder="參與者郵箱將在這裡顯示,您可以編輯",
|
| 754 |
-
lines=1,
|
| 755 |
-
interactive=True
|
| 756 |
-
)
|
| 757 |
-
|
| 758 |
-
# 創建按鈕
|
| 759 |
-
create_event_btn = gr.Button("✅ 創建事件", variant="primary", scale=1)
|
| 760 |
-
|
| 761 |
-
# 操作結果顯示
|
| 762 |
-
calendar_result_display = gr.Textbox(
|
| 763 |
-
label="📊 操作結果",
|
| 764 |
-
lines=8,
|
| 765 |
-
interactive=False
|
| 766 |
-
)
|
| 767 |
-
|
| 768 |
-
# 生成時間選項(每30分鐘一個選項)
|
| 769 |
-
def generate_time_options():
|
| 770 |
-
"""生成時間選項列表"""
|
| 771 |
-
times = []
|
| 772 |
-
for hour in range(24):
|
| 773 |
-
for minute in [0, 30]:
|
| 774 |
-
time_str = f"{hour:02d}:{minute:02d}"
|
| 775 |
-
times.append(time_str)
|
| 776 |
-
return times
|
| 777 |
-
|
| 778 |
-
# 生成日期選項(今天、明天、後天,以及未來7天)
|
| 779 |
-
def generate_date_options():
|
| 780 |
-
"""生成日期選項列表"""
|
| 781 |
-
dates = []
|
| 782 |
-
today = datetime.now()
|
| 783 |
-
date_names = ["今天", "明天", "後天"]
|
| 784 |
-
|
| 785 |
-
for i in range(3):
|
| 786 |
-
date_obj = today + timedelta(days=i)
|
| 787 |
-
date_str = date_obj.strftime('%Y-%m-%d')
|
| 788 |
-
dates.append(f"{date_names[i]} ({date_str})")
|
| 789 |
-
|
| 790 |
-
for i in range(3, 7):
|
| 791 |
-
date_obj = today + timedelta(days=i)
|
| 792 |
-
date_str = date_obj.strftime('%Y-%m-%d')
|
| 793 |
-
dates.append(date_str)
|
| 794 |
-
|
| 795 |
-
return dates
|
| 796 |
-
|
| 797 |
-
# 快速選擇事件模板生成函數
|
| 798 |
-
def generate_quick_prompt(event_type: str) -> str:
|
| 799 |
-
"""根據事件類型生成預設提示"""
|
| 800 |
-
from datetime import datetime, timedelta
|
| 801 |
-
|
| 802 |
-
# 獲取明天的日期
|
| 803 |
-
tomorrow = datetime.now() + timedelta(days=1)
|
| 804 |
-
tomorrow_str = tomorrow.strftime("%Y-%m-%d")
|
| 805 |
-
|
| 806 |
-
templates = {
|
| 807 |
-
"meeting": f"明天下午2點團隊會議,討論項目進度和下週計劃,地點在會議室,參與者包括團隊成員",
|
| 808 |
-
"client": f"明天上午10點客戶拜訪,討論合作方案和需求,地點在客戶公司或會議室",
|
| 809 |
-
"lunch": f"明天中午12點午餐會議,與合作夥伴討論業務合作,地點在附近的餐廳",
|
| 810 |
-
"oneonone": f"明天下午3點一對一會議,討論工作進展和職業發展,地點在會議室或咖啡廳",
|
| 811 |
-
"project": f"明天上午9點項目討論會議,審查項目進度和解決問題,地點在項目室,參與者包括項目團隊",
|
| 812 |
-
"training": f"明天下午2點培訓課程,學習新技能和最佳實踐,地點在培訓室或線上",
|
| 813 |
-
"social": f"明天晚上6點團隊聚餐,慶祝項目完成,地點在餐廳,參與者包括團隊成員",
|
| 814 |
-
"custom": "" # 自定義,返回空讓用戶輸入
|
| 815 |
-
}
|
| 816 |
-
|
| 817 |
-
return templates.get(event_type, "")
|
| 818 |
-
|
| 819 |
-
# 快速選擇按鈕處理函數(自動生成草稿)
|
| 820 |
-
def quick_select_and_generate(event_type: str):
|
| 821 |
-
"""快速選擇事件類型並自動生成草稿"""
|
| 822 |
-
prompt = generate_quick_prompt(event_type)
|
| 823 |
-
if not prompt:
|
| 824 |
-
# 如果是自定義,只返回空提示,不自動生成
|
| 825 |
-
return (
|
| 826 |
-
prompt, # calendar_prompt_input
|
| 827 |
-
"請在下方輸入框中輸入事件提示,然後點擊「生成事件草稿」", # calendar_status_display
|
| 828 |
-
"等待輸入...", # calendar_reflection_display
|
| 829 |
-
gr.update(visible=False), # missing_info_group
|
| 830 |
-
gr.update(visible=False, choices=[]), # missing_date_display
|
| 831 |
-
gr.update(visible=False, choices=[]), # missing_time_display
|
| 832 |
-
gr.update(visible=False), # fill_missing_btn
|
| 833 |
-
"", "", "", "", "", "", # event fields
|
| 834 |
-
{}, # event_dict_storage
|
| 835 |
-
"" # calendar_result_display
|
| 836 |
-
)
|
| 837 |
-
|
| 838 |
-
# 自動生成草稿(調用 generate_draft 並返回所有輸出)
|
| 839 |
-
draft_result = generate_draft(prompt)
|
| 840 |
-
# generate_draft 返回的格式是:(status, reflection_display, missing_info_group, ...)
|
| 841 |
-
# 但我們需要返回 (prompt, status, reflection_display, ...)
|
| 842 |
-
# draft_result 是一個元組,我們需要將 prompt 添加到開頭
|
| 843 |
-
return (prompt,) + draft_result
|
| 844 |
-
|
| 845 |
-
def quick_select_meeting():
|
| 846 |
-
"""快速選擇:團隊會議"""
|
| 847 |
-
return quick_select_and_generate("meeting")
|
| 848 |
-
|
| 849 |
-
def quick_select_client():
|
| 850 |
-
"""快速選擇:客戶拜訪"""
|
| 851 |
-
return quick_select_and_generate("client")
|
| 852 |
-
|
| 853 |
-
def quick_select_lunch():
|
| 854 |
-
"""快速選擇:午餐會議"""
|
| 855 |
-
return quick_select_and_generate("lunch")
|
| 856 |
-
|
| 857 |
-
def quick_select_oneonone():
|
| 858 |
-
"""快速選擇:一對一會議"""
|
| 859 |
-
return quick_select_and_generate("oneonone")
|
| 860 |
-
|
| 861 |
-
def quick_select_project():
|
| 862 |
-
"""快速選擇:項目討論"""
|
| 863 |
-
return quick_select_and_generate("project")
|
| 864 |
-
|
| 865 |
-
def quick_select_training():
|
| 866 |
-
"""快速選擇:培訓/學習"""
|
| 867 |
-
return quick_select_and_generate("training")
|
| 868 |
-
|
| 869 |
-
def quick_select_social():
|
| 870 |
-
"""快速選擇:社交活動"""
|
| 871 |
-
return quick_select_and_generate("social")
|
| 872 |
-
|
| 873 |
-
def quick_select_custom():
|
| 874 |
-
"""快速選擇:自定義輸入(只清空,不自動生成)"""
|
| 875 |
-
return (
|
| 876 |
-
"", # calendar_prompt_input
|
| 877 |
-
"請在下方輸入框中輸入事件提示,然後點擊「生成事件草稿」", # calendar_status_display
|
| 878 |
-
"等待輸入...", # calendar_reflection_display
|
| 879 |
-
gr.update(visible=False), # missing_info_group
|
| 880 |
-
gr.update(visible=False, choices=[]), # missing_date_display
|
| 881 |
-
gr.update(visible=False, choices=[]), # missing_time_display
|
| 882 |
-
gr.update(visible=False), # fill_missing_btn
|
| 883 |
-
"", "", "", "", "", "", # event fields
|
| 884 |
-
{}, # event_dict_storage
|
| 885 |
-
"" # calendar_result_display
|
| 886 |
-
)
|
| 887 |
-
|
| 888 |
-
# 事件處理函數
|
| 889 |
-
def generate_draft(prompt):
|
| 890 |
-
"""生成行事曆事件草稿(包含反思功能)"""
|
| 891 |
-
if not prompt or not prompt.strip():
|
| 892 |
-
return (
|
| 893 |
-
"❌ 請輸入事件提示",
|
| 894 |
-
"❌ 請輸入事件提示",
|
| 895 |
-
gr.update(visible=False),
|
| 896 |
-
gr.update(visible=False, choices=[]),
|
| 897 |
-
gr.update(visible=False, choices=[]),
|
| 898 |
-
gr.update(visible=False),
|
| 899 |
-
"", "", "", "", "", "", {},
|
| 900 |
-
"❌ 請輸入事件提示"
|
| 901 |
-
)
|
| 902 |
-
|
| 903 |
-
try:
|
| 904 |
-
status_msg = "🔄 正在生成事件草稿..."
|
| 905 |
-
|
| 906 |
-
# 生成事件草稿(包含反思功能)
|
| 907 |
-
event_dict, status, missing_info, reflection_result, was_improved = generate_calendar_draft(
|
| 908 |
-
prompt.strip(), enable_reflection=True
|
| 909 |
-
)
|
| 910 |
-
|
| 911 |
-
if not event_dict:
|
| 912 |
-
return (
|
| 913 |
-
status,
|
| 914 |
-
gr.update(visible=False),
|
| 915 |
-
gr.update(visible=False, choices=[]),
|
| 916 |
-
gr.update(visible=False, choices=[]),
|
| 917 |
-
gr.update(visible=False),
|
| 918 |
-
"", "", "", "", "", "", "",
|
| 919 |
-
status
|
| 920 |
-
)
|
| 921 |
-
|
| 922 |
-
# 格式化反思結果顯示
|
| 923 |
-
if reflection_result:
|
| 924 |
-
# 計算反思輪數
|
| 925 |
-
reflection_count = reflection_result.count("【第") if "【第" in reflection_result else 0
|
| 926 |
-
|
| 927 |
-
if was_improved:
|
| 928 |
-
if reflection_count > 1:
|
| 929 |
-
reflection_display = (
|
| 930 |
-
f"🔍 **AI 迭代反思評估結果**(共 {reflection_count} 輪)\n\n"
|
| 931 |
-
f"{reflection_result}\n\n"
|
| 932 |
-
f"✨ **已自動應用改進建議,經過 {reflection_count} 輪優化,當前顯示的是最終優化版本**"
|
| 933 |
-
)
|
| 934 |
-
else:
|
| 935 |
-
reflection_display = (
|
| 936 |
-
f"🔍 **AI 反思評估結果**\n\n"
|
| 937 |
-
f"{reflection_result}\n\n"
|
| 938 |
-
f"✨ **已自動應用改進建議,當前顯示的是優化後的版本**"
|
| 939 |
-
)
|
| 940 |
-
else:
|
| 941 |
-
reflection_display = (
|
| 942 |
-
f"🔍 **AI 反思評估結果**\n\n"
|
| 943 |
-
f"{reflection_result}\n\n"
|
| 944 |
-
f"✅ **事件質量良好,無需改進**"
|
| 945 |
-
)
|
| 946 |
-
else:
|
| 947 |
-
reflection_display = "⚠️ 反思功能未返回結果"
|
| 948 |
-
|
| 949 |
-
# 【Google Maps 整合】添加地點建議訊息
|
| 950 |
-
location_suggestion = event_dict.get("location_suggestion", "")
|
| 951 |
-
if location_suggestion:
|
| 952 |
-
# 將地點建議添加到狀態訊息中
|
| 953 |
-
if status:
|
| 954 |
-
status = f"{status}\n\n🗺️ **地點資訊:**\n{location_suggestion}"
|
| 955 |
-
else:
|
| 956 |
-
status = f"🗺️ **地點資訊:**\n{location_suggestion}"
|
| 957 |
-
|
| 958 |
-
# 檢查是否有缺失資訊
|
| 959 |
-
has_missing = bool(missing_info)
|
| 960 |
-
|
| 961 |
-
if has_missing:
|
| 962 |
-
# 顯示缺失資訊區域
|
| 963 |
-
date_visible = missing_info.get("date", False)
|
| 964 |
-
time_visible = missing_info.get("time", False)
|
| 965 |
-
|
| 966 |
-
date_choices = generate_date_options() if date_visible else []
|
| 967 |
-
time_choices = generate_time_options() if time_visible else []
|
| 968 |
-
|
| 969 |
-
return (
|
| 970 |
-
status,
|
| 971 |
-
reflection_display,
|
| 972 |
-
gr.update(visible=True), # 顯示缺失資訊區域
|
| 973 |
-
gr.update(visible=date_visible, choices=date_choices, value=date_choices[0] if date_choices else None),
|
| 974 |
-
gr.update(visible=time_visible, choices=time_choices, value=time_choices[0] if time_choices else None),
|
| 975 |
-
gr.update(visible=True), # 顯示確認按鈕
|
| 976 |
-
event_dict.get("summary", ""),
|
| 977 |
-
event_dict.get("start_datetime", ""),
|
| 978 |
-
event_dict.get("end_datetime", ""),
|
| 979 |
-
event_dict.get("description", ""),
|
| 980 |
-
event_dict.get("location", ""),
|
| 981 |
-
event_dict.get("attendees", ""),
|
| 982 |
-
event_dict, # 傳遞完整的事件字典以便後續使用
|
| 983 |
-
""
|
| 984 |
-
)
|
| 985 |
-
else:
|
| 986 |
-
# 沒有缺失資訊,直接顯示結果
|
| 987 |
-
return (
|
| 988 |
-
status,
|
| 989 |
-
reflection_display,
|
| 990 |
-
gr.update(visible=False),
|
| 991 |
-
gr.update(visible=False, choices=[]),
|
| 992 |
-
gr.update(visible=False, choices=[]),
|
| 993 |
-
gr.update(visible=False),
|
| 994 |
-
event_dict.get("summary", ""),
|
| 995 |
-
event_dict.get("start_datetime", ""),
|
| 996 |
-
event_dict.get("end_datetime", ""),
|
| 997 |
-
event_dict.get("description", ""),
|
| 998 |
-
event_dict.get("location", ""),
|
| 999 |
-
event_dict.get("attendees", ""),
|
| 1000 |
-
event_dict,
|
| 1001 |
-
""
|
| 1002 |
-
)
|
| 1003 |
-
except Exception as e:
|
| 1004 |
-
error_msg = f"❌ 發生錯誤:{str(e)}"
|
| 1005 |
-
print(f"Calendar Tool 錯誤:{e}")
|
| 1006 |
-
import traceback
|
| 1007 |
-
traceback.print_exc()
|
| 1008 |
-
return (
|
| 1009 |
-
"❌ 發生錯誤",
|
| 1010 |
-
f"❌ 發生錯誤:{str(e)}",
|
| 1011 |
-
gr.update(visible=False),
|
| 1012 |
-
gr.update(visible=False, choices=[]),
|
| 1013 |
-
gr.update(visible=False, choices=[]),
|
| 1014 |
-
gr.update(visible=False),
|
| 1015 |
-
"", "", "", "", "", "", {},
|
| 1016 |
-
error_msg
|
| 1017 |
-
)
|
| 1018 |
-
|
| 1019 |
-
def fill_missing_info(event_dict_storage, selected_date, selected_time):
|
| 1020 |
-
"""填充缺失的資訊"""
|
| 1021 |
-
if not event_dict_storage:
|
| 1022 |
-
return (
|
| 1023 |
-
"❌ 沒有事件資料",
|
| 1024 |
-
gr.update(visible=False),
|
| 1025 |
-
gr.update(visible=False, choices=[]),
|
| 1026 |
-
gr.update(visible=False, choices=[]),
|
| 1027 |
-
gr.update(visible=False),
|
| 1028 |
-
"", "", "", "", "", "",
|
| 1029 |
-
{}
|
| 1030 |
-
)
|
| 1031 |
-
|
| 1032 |
-
# 更新日期和時間
|
| 1033 |
-
if selected_date:
|
| 1034 |
-
# 從選項中提取日期字串(例如:"明天 (2026-01-25)" -> "2026-01-25")
|
| 1035 |
-
if "(" in selected_date:
|
| 1036 |
-
date_str = selected_date.split("(")[1].split(")")[0]
|
| 1037 |
-
else:
|
| 1038 |
-
date_str = selected_date
|
| 1039 |
-
else:
|
| 1040 |
-
date_str = event_dict_storage.get("date", "今天")
|
| 1041 |
-
|
| 1042 |
-
if selected_time:
|
| 1043 |
-
time_str = selected_time
|
| 1044 |
-
else:
|
| 1045 |
-
time_str = "09:00" # 預設時間
|
| 1046 |
-
|
| 1047 |
-
# 重新解析日期和時間
|
| 1048 |
-
from ..agents.calendar_agent import parse_datetime
|
| 1049 |
-
start_datetime, end_datetime = parse_datetime(date_str, time_str)
|
| 1050 |
-
|
| 1051 |
-
# 更新事件字典
|
| 1052 |
-
event_dict_storage["start_datetime"] = start_datetime
|
| 1053 |
-
event_dict_storage["end_datetime"] = end_datetime
|
| 1054 |
-
|
| 1055 |
-
return (
|
| 1056 |
-
"✅ 資訊已補充,請檢查並創建事件",
|
| 1057 |
-
gr.update(visible=False), # 隱藏缺失資訊區域
|
| 1058 |
-
gr.update(visible=False, choices=[]),
|
| 1059 |
-
gr.update(visible=False, choices=[]),
|
| 1060 |
-
gr.update(visible=False),
|
| 1061 |
-
event_dict_storage.get("summary", ""),
|
| 1062 |
-
start_datetime,
|
| 1063 |
-
end_datetime,
|
| 1064 |
-
event_dict_storage.get("description", ""),
|
| 1065 |
-
event_dict_storage.get("location", ""),
|
| 1066 |
-
event_dict_storage.get("attendees", ""),
|
| 1067 |
-
event_dict_storage
|
| 1068 |
-
)
|
| 1069 |
-
|
| 1070 |
-
def create_event(summary, start_datetime, end_datetime, description, location, attendees):
|
| 1071 |
-
"""創建行事曆事件"""
|
| 1072 |
-
if not summary or not summary.strip():
|
| 1073 |
-
return "❌ 請輸入事件標題", "❌ 請輸入事件標題"
|
| 1074 |
-
|
| 1075 |
-
if not start_datetime or not start_datetime.strip():
|
| 1076 |
-
return "❌ 請輸入開始時間", "❌ 請輸入開始時間"
|
| 1077 |
-
|
| 1078 |
-
if not end_datetime or not end_datetime.strip():
|
| 1079 |
-
return "❌ 請輸入結束時間", "❌ 請輸入結束時間"
|
| 1080 |
-
|
| 1081 |
-
try:
|
| 1082 |
-
status_msg = "🔄 正在創建事件..."
|
| 1083 |
-
|
| 1084 |
-
# 構建事件字典
|
| 1085 |
-
event_dict = {
|
| 1086 |
-
"summary": summary.strip(),
|
| 1087 |
-
"start_datetime": start_datetime.strip(),
|
| 1088 |
-
"end_datetime": end_datetime.strip(),
|
| 1089 |
-
"description": description.strip() if description else "",
|
| 1090 |
-
"location": location.strip() if location else "",
|
| 1091 |
-
"attendees": attendees.strip() if attendees else "",
|
| 1092 |
-
"timezone": "Asia/Taipei"
|
| 1093 |
-
}
|
| 1094 |
-
|
| 1095 |
-
# 創建事件
|
| 1096 |
-
result = create_calendar_draft(event_dict)
|
| 1097 |
-
|
| 1098 |
-
return "✅ 事件已創建", result
|
| 1099 |
-
except Exception as e:
|
| 1100 |
-
error_msg = f"❌ 創建事件時發生錯誤:{str(e)}"
|
| 1101 |
-
print(f"Calendar Tool 錯誤:{e}")
|
| 1102 |
-
import traceback
|
| 1103 |
-
traceback.print_exc()
|
| 1104 |
-
return "❌ 發生錯誤", error_msg
|
| 1105 |
-
|
| 1106 |
-
def clear_calendar():
|
| 1107 |
-
"""清除行事曆相關輸入和輸出"""
|
| 1108 |
-
return (
|
| 1109 |
-
"", # prompt
|
| 1110 |
-
"等待操作...", # status
|
| 1111 |
-
"等待生成事件...", # reflection_display
|
| 1112 |
-
gr.update(visible=False), # missing_info_group
|
| 1113 |
-
gr.update(visible=False, choices=[]), # missing_date
|
| 1114 |
-
gr.update(visible=False, choices=[]), # missing_time
|
| 1115 |
-
gr.update(visible=False), # fill_missing_btn
|
| 1116 |
-
"", "", "", "", "", "", # event fields
|
| 1117 |
-
{}, # event_dict_storage
|
| 1118 |
-
"" # result
|
| 1119 |
-
)
|
| 1120 |
-
|
| 1121 |
-
# 綁定事件
|
| 1122 |
-
generate_draft_btn.click(
|
| 1123 |
-
fn=generate_draft,
|
| 1124 |
-
inputs=[calendar_prompt_input],
|
| 1125 |
-
outputs=[
|
| 1126 |
-
calendar_status_display,
|
| 1127 |
-
calendar_reflection_display,
|
| 1128 |
-
missing_info_group,
|
| 1129 |
-
missing_date_display,
|
| 1130 |
-
missing_time_display,
|
| 1131 |
-
fill_missing_btn,
|
| 1132 |
-
event_summary_display,
|
| 1133 |
-
event_start_display,
|
| 1134 |
-
event_end_display,
|
| 1135 |
-
event_description_display,
|
| 1136 |
-
event_location_display,
|
| 1137 |
-
event_attendees_display,
|
| 1138 |
-
event_dict_storage,
|
| 1139 |
-
calendar_result_display
|
| 1140 |
-
]
|
| 1141 |
-
)
|
| 1142 |
-
|
| 1143 |
-
# 綁定快速選擇按鈕(自動填充提示並生成草稿)
|
| 1144 |
-
quick_outputs = [
|
| 1145 |
-
calendar_prompt_input, # 更新提示輸入框
|
| 1146 |
-
calendar_status_display,
|
| 1147 |
-
calendar_reflection_display,
|
| 1148 |
-
missing_info_group,
|
| 1149 |
-
missing_date_display,
|
| 1150 |
-
missing_time_display,
|
| 1151 |
-
fill_missing_btn,
|
| 1152 |
-
event_summary_display,
|
| 1153 |
-
event_start_display,
|
| 1154 |
-
event_end_display,
|
| 1155 |
-
event_description_display,
|
| 1156 |
-
event_location_display,
|
| 1157 |
-
event_attendees_display,
|
| 1158 |
-
event_dict_storage,
|
| 1159 |
-
calendar_result_display
|
| 1160 |
-
]
|
| 1161 |
-
|
| 1162 |
-
quick_meeting_btn.click(fn=quick_select_meeting, outputs=quick_outputs)
|
| 1163 |
-
quick_client_btn.click(fn=quick_select_client, outputs=quick_outputs)
|
| 1164 |
-
quick_lunch_btn.click(fn=quick_select_lunch, outputs=quick_outputs)
|
| 1165 |
-
quick_oneonone_btn.click(fn=quick_select_oneonone, outputs=quick_outputs)
|
| 1166 |
-
quick_project_btn.click(fn=quick_select_project, outputs=quick_outputs)
|
| 1167 |
-
quick_training_btn.click(fn=quick_select_training, outputs=quick_outputs)
|
| 1168 |
-
quick_social_btn.click(fn=quick_select_social, outputs=quick_outputs)
|
| 1169 |
-
quick_custom_btn.click(fn=quick_select_custom, outputs=quick_outputs)
|
| 1170 |
-
|
| 1171 |
-
fill_missing_btn.click(
|
| 1172 |
-
fn=fill_missing_info,
|
| 1173 |
-
inputs=[event_dict_storage, missing_date_display, missing_time_display],
|
| 1174 |
-
outputs=[
|
| 1175 |
-
calendar_status_display,
|
| 1176 |
-
missing_info_group,
|
| 1177 |
-
missing_date_display,
|
| 1178 |
-
missing_time_display,
|
| 1179 |
-
fill_missing_btn,
|
| 1180 |
-
event_summary_display,
|
| 1181 |
-
event_start_display,
|
| 1182 |
-
event_end_display,
|
| 1183 |
-
event_description_display,
|
| 1184 |
-
event_location_display,
|
| 1185 |
-
event_attendees_display,
|
| 1186 |
-
event_dict_storage
|
| 1187 |
-
]
|
| 1188 |
-
)
|
| 1189 |
-
|
| 1190 |
-
create_event_btn.click(
|
| 1191 |
-
fn=create_event,
|
| 1192 |
-
inputs=[
|
| 1193 |
-
event_summary_display,
|
| 1194 |
-
event_start_display,
|
| 1195 |
-
event_end_display,
|
| 1196 |
-
event_description_display,
|
| 1197 |
-
event_location_display,
|
| 1198 |
-
event_attendees_display
|
| 1199 |
-
],
|
| 1200 |
-
outputs=[calendar_status_display, calendar_result_display]
|
| 1201 |
-
)
|
| 1202 |
-
|
| 1203 |
-
clear_calendar_btn.click(
|
| 1204 |
-
fn=clear_calendar,
|
| 1205 |
-
outputs=[
|
| 1206 |
-
calendar_prompt_input,
|
| 1207 |
-
calendar_status_display,
|
| 1208 |
-
calendar_reflection_display,
|
| 1209 |
-
missing_info_group,
|
| 1210 |
-
missing_date_display,
|
| 1211 |
-
missing_time_display,
|
| 1212 |
-
fill_missing_btn,
|
| 1213 |
-
event_summary_display,
|
| 1214 |
-
event_start_display,
|
| 1215 |
-
event_end_display,
|
| 1216 |
-
event_description_display,
|
| 1217 |
-
event_location_display,
|
| 1218 |
-
event_attendees_display,
|
| 1219 |
-
event_dict_storage,
|
| 1220 |
-
calendar_result_display
|
| 1221 |
-
]
|
| 1222 |
-
)
|
| 1223 |
-
|
| 1224 |
-
# 示例
|
| 1225 |
-
gr.Examples(
|
| 1226 |
-
examples=[
|
| 1227 |
-
"明天下午2點團隊會議,討論項目進度,地點在會議室A,參與者包括john@example.com",
|
| 1228 |
-
"2026-01-25 上午9點產品發布會,介紹新功能和改進,地點在總部大樓",
|
| 1229 |
-
"後天下午3點客戶會議,討論合作細節,參與者包括客戶代表",
|
| 1230 |
-
"下週一上午10點技術分享會,分享最新的 AI 技術,地點在研發中心"
|
| 1231 |
-
],
|
| 1232 |
-
inputs=[calendar_prompt_input]
|
| 1233 |
-
)
|
| 1234 |
-
|
| 1235 |
-
# 頁腳說明
|
| 1236 |
-
gr.Markdown(
|
| 1237 |
-
"""
|
| 1238 |
-
---
|
| 1239 |
-
**注意事項:**
|
| 1240 |
-
1. 使用 Google Calendar API 管理行事曆事件
|
| 1241 |
-
2. 首次使用需要在專案根目錄放置 `credentials.json`(從 Google Cloud Console 下載的 OAuth2 憑證)
|
| 1242 |
-
3. 首次運行時會自動開啟瀏覽器進行授權,授權後會生成 `token.json` 文件
|
| 1243 |
-
4. 事件內容由 AI 自動生成,請在創建前檢查結果
|
| 1244 |
-
5. 在提示中包含所有資訊:事件、日期、時間、地點、參與者
|
| 1245 |
-
6. 如果缺少日期或時間,系統會顯示下拉選單讓您選擇
|
| 1246 |
-
7. 日期格式支援:YYYY-MM-DD(例如:2026-01-25)或相對日期(今天��明天、後天)
|
| 1247 |
-
8. 時間格式支援:24小時制(14:00)或12小時制(2:00 PM)
|
| 1248 |
-
|
| 1249 |
-
**設置步驟:**
|
| 1250 |
-
- 前往 [Google Cloud Console](https://console.cloud.google.com/) 創建專案
|
| 1251 |
-
- 啟用 Google Calendar API
|
| 1252 |
-
- 創建 OAuth2 憑證並下載為 `credentials.json`
|
| 1253 |
-
- 將 `credentials.json` 放在專案根目錄
|
| 1254 |
-
- 確保授予 Calendar API 的完整存取權限
|
| 1255 |
-
"""
|
| 1256 |
-
)
|
| 1257 |
|
|
|
|
|
|
|
|
|
| 5 |
import uuid
|
| 6 |
import re
|
| 7 |
import time
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
from typing import Iterator, Tuple
|
| 11 |
import gradio as gr
|
| 12 |
from langchain_core.messages import HumanMessage
|
| 13 |
|
| 14 |
# graph 和 rag_retriever 將從外部傳入,不在這裡導入
|
| 15 |
from ..utils.llm_utils import get_llm_type, is_using_local_llm
|
| 16 |
+
from .email_interface import _create_email_interface
|
| 17 |
+
from .calendar_interface import _create_calendar_interface
|
| 18 |
+
from .private_file_rag_interface import _create_private_file_rag_interface
|
| 19 |
|
| 20 |
|
| 21 |
def run_research_agent(query: str, graph, thread_id: str = None) -> Iterator[Tuple[str, str, str, str, str]]:
|
|
|
|
| 209 |
gr.Markdown(
|
| 210 |
"""
|
| 211 |
<div class="header">
|
| 212 |
+
<h1>🚀 Deep Research Agent with RAG</h1>
|
| 213 |
<p><strong>功能特色:</strong></p>
|
| 214 |
+
<p>📊 股票資訊查詢 | 🌐 網路搜尋 | 📚 PDF 知識庫查詢(Tree of Thoughts 論文)| 📧 智能郵件助手 | 📅 智能行事曆管理 | 📄 私有文件 RAG 問答</p>
|
| 215 |
<p><strong>智能規劃:</strong> 系統會根據問題類型自動選擇合適的研究工具</p>
|
| 216 |
<p><strong>本地模型:</strong> 使用 MLX 本地模型,保護隱私,無需 API 金鑰</p>
|
| 217 |
</div>
|
|
|
|
| 232 |
# Tab 3: Calendar Tool
|
| 233 |
with gr.Tab("📅 Calendar Tool"):
|
| 234 |
_create_calendar_interface()
|
| 235 |
+
|
| 236 |
+
# Tab 4: Private File RAG
|
| 237 |
+
with gr.Tab("📚 Private File RAG"):
|
| 238 |
+
_create_private_file_rag_interface()
|
| 239 |
|
| 240 |
return demo
|
| 241 |
|
|
|
|
| 372 |
)
|
| 373 |
|
| 374 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
|
| 376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
|
| 378 |
+
|
| 379 |
+
|
deep_agent_rag/ui/private_file_rag_interface.py
ADDED
|
@@ -0,0 +1,663 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# deep_agent_rag/ui/private_file_rag_interface.py
|
| 2 |
+
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import re
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
from ..rag.private_file_rag import get_private_rag_instance, reset_private_rag_instance
|
| 10 |
+
# Assuming is_using_local_llm might be used for warnings/status, similar to email_interface
|
| 11 |
+
# from ..utils.llm_utils import is_using_local_llm
|
| 12 |
+
|
| 13 |
+
# Agent log path for debugging (if needed)
|
| 14 |
+
log_path = "/Users/matthuang/Desktop/Deep_Agentic_AI_Tool/.cursor/debug.log"
|
| 15 |
+
|
| 16 |
+
def _create_private_file_rag_interface():
    """Build the private-file RAG chat tab (conversational Chatbot UI).

    Lays out the upload/settings column on the left and the chat column on the
    right, then (further down in this function) defines the callbacks and wires
    the Gradio events. All user-facing strings are Traditional Chinese.
    """
    # Intro / usage notes rendered at the top of the tab (user-facing text).
    gr.Markdown(
        """
        ### 📚 私有文件 RAG 對話系統

        上傳您的私有文件(PDF、DOCX、TXT),系統會自動建立 RAG 知識庫,讓 AI 可以回答關於這些文件的問題。
        支持多輪對話,AI 會記住之前的對話內容,提供更連貫的回答。

        **使用方式:**
        1. 上傳一個或多個文件(PDF、DOCX、TXT)
        2. 點擊「處理文件」按鈕,系統會自動處理文件並建立 RAG 系統
        3. 在對話框中輸入您的問題,按 Enter 或點擊「發送」按鈕
        4. AI 會基於上傳的文件回答問題,支持多輪對話

        **功能特色:**
        - 💬 **對話式界面** :類似 Gemini 的對話體驗,支持多輪對話
        - 📄 支持多種文件格式:PDF、DOCX、TXT
        - 🔍 使用混合搜尋(BM25 + 向量檢索)提升檢索準確度
        - 🎯 可選重排序功能,進一步優化結果
        - 🧠 支持語義分塊,保持語義完整性
        - 🌐 自動檢測文檔類型並調整回答風格

        **LLM 使用策略:**
        - 🥇 **優先使用 Groq API** :如果配置了 API 金鑰,優先使用 Groq(速度快、質量高)
        - 🥈 **其次使用 Ollama** :如果 Groq 不可用,自動切換到 Ollama 本地模型
        - 🥉 **最後使用 MLX** :如果前兩者都不可用,使用 MLX 本地模型作為備選
        - 💡 **自動切換** :系統會根據 API 額度、服務狀態等自動選擇最合適的 LLM

        **注意:** 此功能需要 Learn_RAG 項目在正確的位置
        """
    )

    # Conversation transcript shared between callbacks (per-session Gradio state).
    chat_history = gr.State(value=[])

    with gr.Row():
        # Left column: file upload and processing settings.
        with gr.Column(scale=1):
            # File upload area (multiple files allowed).
            file_upload = gr.File(
                label="📁 上傳文件(PDF、DOCX、TXT)",
                file_count="multiple",
                file_types=[".pdf", ".docx", ".doc", ".txt"]
            )

            # Process / clear-everything buttons.
            with gr.Row():
                process_btn = gr.Button("📝 處理文件", variant="primary", scale=1)
                clear_files_btn = gr.Button("🗑️ 清除所有", variant="secondary", scale=1)

            # Processing status readout.
            process_status = gr.Textbox(
                label="📊 處理狀態",
                value="等待上傳文件...",
                interactive=False,
                lines=2
            )

            # Advanced settings, collapsed by default.
            with gr.Accordion("⚙️ 進階設置", open=False):
                # Chunking mode toggle.
                use_semantic_chunking = gr.Checkbox(
                    label="使用語義分塊(推薦)",
                    value=False,
                    info="語義分塊能保持語義完整性,但處理時間較長"
                )

                # Character-chunking parameters (used only when semantic chunking is off).
                gr.Markdown("**📏 字符分塊參數(僅在未使用語義分塊時有效)**")
                chunk_size_slider = gr.Slider(
                    minimum=200,
                    maximum=1500,
                    value=500,
                    step=50,
                    label="分塊大小(字符數)",
                    info="建議:300-800"
                )
                chunk_overlap_slider = gr.Slider(
                    minimum=0,
                    maximum=300,
                    value=100,
                    step=25,
                    label="分塊重疊(字符數)",
                    info="建議:chunk_size 的 15-25%"
                )

                # Semantic-chunking parameters (used only when semantic chunking is on).
                gr.Markdown("**🔬 語義分塊參數(僅在使用語義分塊時有效)**")
                semantic_threshold_slider = gr.Slider(
                    minimum=0.5,
                    maximum=2.5,
                    value=1.0,
                    step=0.1,
                    label="語義分塊閾值(敏感度)",
                    info="建議:0.8-1.2(細粒度)"
                )
                semantic_min_chunk_slider = gr.Slider(
                    minimum=50,
                    maximum=300,
                    value=100,
                    step=25,
                    label="最小分塊大小(字符數)",
                    info="建議:50-200"
                )

                # RAG method selection: adaptive (automatic) or a manual pick.
                gr.Markdown("**🎯 RAG 方法選擇**")
                enable_adaptive_selection = gr.Checkbox(
                    label="自動選擇最佳 RAG 方法(推薦)",
                    value=True,
                    info="系統會根據查詢和文件特征自動選擇最合適的 RAG 方法"
                )
                # Hidden unless adaptive selection is turned off (see toggle below).
                manual_rag_method = gr.Dropdown(
                    choices=[
                        "basic",
                        "subquery",
                        "hyde",
                        "step_back",
                        "hybrid_subquery_hyde",
                        "triple_hybrid"
                    ],
                    value="basic",
                    label="手動選擇 RAG 方法",
                    info="僅在自動選擇關閉時生效",
                    visible=False
                )

                # Query options.
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="返回結果數量"
                )
                use_llm_checkbox = gr.Checkbox(
                    label="使用 LLM 生成回答",
                    value=True
                )

        # Right column: the conversational interface.
        with gr.Column(scale=2):
            # Best-effort debug telemetry written to `log_path`.
            # NOTE(review): bare `except: pass` hides all failures here —
            # acceptable for a throwaway debug hook, but consider removing.
            # #region agent log
            try:
                with open(log_path, "a", encoding="utf-8") as f:
                    log_entry = {
                        "sessionId": "debug-session",
                        "runId": "run1",
                        "hypothesisId": "A",
                        "location": "private_file_rag_interface.py:1409",  # Adjusted line number
                        "message": "Before Chatbot creation",
                        "data": {
                            "gradio_version": gr.__version__ if hasattr(gr, '__version__') else "unknown"
                        },
                        "timestamp": int(time.time() * 1000)
                    }
                    f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
            except:
                pass
            # #endregion

            # Create the Chatbot with minimal params (unsupported kwargs such as
            # show_copy_button / avatar_images were removed for compatibility).
            # #region agent log
            try:
                with open(log_path, "a", encoding="utf-8") as f:
                    log_entry = {
                        "sessionId": "debug-session",
                        "runId": "run1",
                        "hypothesisId": "A",
                        "location": "private_file_rag_interface.py:1430",  # Adjusted line number
                        "message": "Creating Chatbot with minimal params",
                        "data": {"params": ["label", "height"]},
                        "timestamp": int(time.time() * 1000)
                    }
                    f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
            except:
                pass
            # #endregion

            try:
                chatbot = gr.Chatbot(
                    label="💬 對話",
                    height=500
                )
                # #region agent log
                try:
                    with open(log_path, "a", encoding="utf-8") as f:
                        log_entry = {
                            "sessionId": "debug-session",
                            "runId": "run1",
                            "hypothesisId": "A",
                            "location": "private_file_rag_interface.py:1448",  # Adjusted line number
                            "message": "Chatbot created successfully",
                            "data": {"success": True},
                            "timestamp": int(time.time() * 1000)
                        }
                        f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
                except:
                    pass
                # #endregion
            except Exception as e:
                # Log the failure, then re-raise so interface construction aborts.
                # #region agent log
                try:
                    with open(log_path, "a", encoding="utf-8") as f:
                        log_entry = {
                            "sessionId": "debug-session",
                            "runId": "run1",
                            "hypothesisId": "A",
                            "location": "private_file_rag_interface.py:1460",  # Adjusted line number
                            "message": "Chatbot creation failed",
                            "data": {
                                "error_type": type(e).__name__,
                                "error_message": str(e)
                            },
                            "timestamp": int(time.time() * 1000)
                        }
                        f.write(json.dumps(log_entry, ensure_ascii=False) + "\n")
                except:
                    pass
                # #endregion
                raise

            # Message input box.
            msg = gr.Textbox(
                label="輸入問題",
                placeholder="輸入您的問題,按 Enter 發送...",
                lines=2,
                scale=4
            )

            # Send / clear-chat buttons.
            with gr.Row():
                submit_btn = gr.Button("📤 發送", variant="primary", scale=1)
                clear_chat_btn = gr.Button("🗑️ 清除對話", variant="secondary", scale=1)

            # Last query status line.
            query_status = gr.Textbox(
                label="📊 狀態",
                value="等待查詢...",
                interactive=False,
                lines=1
            )
|
| 260 |
+
|
| 261 |
+
# 輔助函數:轉換 Gradio 歷史格式(dict)和 RAG 歷史格式(tuple)
|
| 262 |
+
def history_dict_to_tuple(history_dict):
    """
    Convert Gradio message-dict history into (user, assistant) pairs.

    A pair is emitted each time an assistant message follows a pending user
    message; consecutive user messages keep only the latest one, and an
    assistant message with no pending user message is dropped. Entries that
    are already 2-tuples pass through unchanged (backward compatibility).
    """
    if not history_dict:
        return []

    pairs = []
    pending_user = None
    for entry in history_dict:
        if isinstance(entry, dict):
            role = entry.get("role", "")
            if role == "user":
                pending_user = entry.get("content", "")
            elif role == "assistant" and pending_user is not None:
                pairs.append((pending_user, entry.get("content", "")))
                pending_user = None
        elif isinstance(entry, tuple) and len(entry) == 2:
            pairs.append(entry)
    return pairs
|
| 293 |
+
|
| 294 |
+
def history_tuple_to_dict(history_tuple):
    """
    Expand (user, assistant) pair history into Gradio message dicts.

    Dict entries pass through untouched; each 2-tuple becomes two dicts — a
    user message followed by an assistant message. Anything else is dropped.
    """
    if not history_tuple:
        return []

    messages = []
    for entry in history_tuple:
        if isinstance(entry, dict):
            messages.append(entry)
        elif isinstance(entry, tuple) and len(entry) == 2:
            user_text, assistant_text = entry
            messages.extend((
                {"role": "user", "content": user_text},
                {"role": "assistant", "content": assistant_text},
            ))
    return messages
|
| 319 |
+
|
| 320 |
+
def ensure_dict_format(history):
    """
    Coerce a chat history of unknown shape into Gradio's message-dict format.

    Empty/None input, unindexable input, or an unrecognized first element all
    yield an empty list; tuple-style histories are converted via
    history_tuple_to_dict. The format is judged from the first element only.
    """
    if not history:
        return []

    try:
        first = history[0]
        if isinstance(first, dict):
            return history
        if isinstance(first, tuple):
            return history_tuple_to_dict(history)
        # Unknown element type: discard rather than guess.
        return []
    except (IndexError, TypeError):
        # `history` was truthy but not indexable.
        return []
|
| 345 |
+
|
| 346 |
+
# 事件處理函數
|
| 347 |
+
def process_files(files, use_semantic, chunk_size, chunk_overlap, semantic_threshold, semantic_min_chunk):
    """
    Process the uploaded files and (re)build the private RAG index.

    Args:
        files: Uploaded files — Gradio 6.x passes path strings; older versions
            pass file objects with a ``.name`` attribute.
        use_semantic: Whether to use semantic chunking.
        chunk_size: Character-chunk size (character-chunking mode only).
        chunk_overlap: Character-chunk overlap (character-chunking mode only).
        semantic_threshold: Split sensitivity (semantic-chunking mode only).
        semantic_min_chunk: Minimum chunk size (semantic-chunking mode only).

    Returns:
        A (process_status, query_status) pair of message strings for the two
        status textboxes.
    """
    if not files:
        return "❌ 請先上傳文件", "等待上傳文件..."

    try:
        # Fetch the shared RAG instance.
        rag = get_private_rag_instance()

        # Push the chunking-mode choice into the instance.
        rag.use_semantic_chunking = use_semantic

        # Apply only the parameters relevant to the chosen chunking mode.
        if not use_semantic:
            # Character-chunking mode.
            rag.chunk_size = int(chunk_size)
            rag.chunk_overlap = int(chunk_overlap)
            print(f"📏 使用字符分塊:chunk_size={rag.chunk_size}, chunk_overlap={rag.chunk_overlap}")
        else:
            # Semantic-chunking mode.
            rag.semantic_threshold = float(semantic_threshold)
            rag.semantic_min_chunk_size = int(semantic_min_chunk)
            print(f"📏 使用語義分塊:threshold={rag.semantic_threshold}, min_chunk_size={rag.semantic_min_chunk_size}")

        # Normalize each upload to a filesystem path (Gradio saves uploads to a
        # temporary directory; 6.x returns plain path strings).
        file_paths = []

        for file in files:
            if isinstance(file, str):
                file_path = file
            elif hasattr(file, 'name'):
                # Older Gradio versions hand back a file object.
                file_path = file.name
            else:
                # Fall back to string conversion.
                file_path = str(file)

            if os.path.exists(file_path):
                file_paths.append(file_path)
            else:
                # Abort the whole batch on the first missing file.
                return f"❌ 文件不存在: {file_path}", "處理失敗"

        if not file_paths:
            return "❌ 沒有有效的文件路徑", "處理失敗"

        # Delegate actual parsing/indexing to the RAG instance.
        documents, status_msg = rag.process_files(file_paths)

        if documents:
            return status_msg, "✅ 文件處理完成,可以開始查詢"
        else:
            return status_msg, "❌ 處理失敗"

    except Exception as e:
        error_msg = f"❌ 處理文件時發生錯誤: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return error_msg, "❌ 處理失敗"
|
| 418 |
+
|
| 419 |
+
def query_rag_stream(message, history, top_k, use_llm, enable_adaptive, manual_method):
    """
    Query the RAG system conversationally, streaming partial answers.

    Args:
        message: Current user message.
        history: Conversation history (Gradio format: list of role dicts, or
            legacy list of (user, assistant) tuples).
        top_k: Number of retrieval results to return.
        use_llm: Whether the LLM should generate an answer (vs. raw retrieval).
        enable_adaptive: Whether the RAG method is chosen automatically.
        manual_method: Manually selected method (only when adaptive is off).

    Yields:
        (history, status_msg) tuples: the progressively updated chat history
        plus a short status line for the UI.
    """
    if not message or not message.strip():
        yield history, "❌ 請輸入問題"
        return

    try:
        rag = get_private_rag_instance()

        # Refuse to query before any documents have been processed.
        if not rag.is_initialized:
            error_msg = "❌ RAG 系統尚未初始化,請先處理文件"
            history = ensure_dict_format(history)
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": error_msg})
            yield history, error_msg
            return

        # Configure method selection on the instance.
        rag.enable_adaptive_selection = enable_adaptive
        if not enable_adaptive:
            rag.selected_rag_method = manual_method
        else:
            rag.selected_rag_method = None

        # The RAG backend consumes (user, assistant) tuples; convert BEFORE
        # appending the new user message so the current turn is not in history.
        conversation_history = history_dict_to_tuple(history) if history else []

        # Normalize and record the user's turn.
        history = ensure_dict_format(history)
        history.append({"role": "user", "content": message})

        if use_llm:
            # Streaming path: each chunk carries the full answer-so-far.
            answer_generator = rag.query_stream(
                query=message,
                top_k=int(top_k),
                conversation_history=conversation_history
            )

            accumulated_answer = ""
            history_with_user = history.copy()
            final_result = {}

            for chunk in answer_generator:
                if chunk.get("success") is False:
                    error = chunk.get("error", "未知錯誤")
                    error_msg = f"❌ 查詢失敗: {error}"
                    history_with_user.append({"role": "assistant", "content": error_msg})
                    yield history_with_user, error_msg
                    return

                # Keep the last chunk: it carries the final stats/method info.
                final_result = chunk

                # Each chunk's "answer" REPLACES the previous one (not appended).
                new_answer = chunk.get("answer", "")
                if new_answer:
                    accumulated_answer = new_answer
                    # Re-render: user turn + current answer snapshot.
                    history_with_answer = history_with_user.copy()
                    history_with_answer.append({"role": "assistant", "content": accumulated_answer})
                    yield history_with_answer, "🔄 正在生成回答..."

            # Build the final status line from the last chunk's metadata.
            rag_method = final_result.get("rag_method", "basic")
            stats = final_result.get("stats", {})
            status_msg = f"✅ 查詢完成(方法: {rag_method.upper()})"
            if stats:
                total_time = stats.get("total_time", 0)
                if total_time > 0:
                    status_msg += f" | 耗時: {total_time:.2f}秒"

            # Emit one last update with the complete answer.
            if accumulated_answer:
                history_with_answer = history_with_user.copy()
                history_with_answer.append({"role": "assistant", "content": accumulated_answer})
                yield history_with_answer, status_msg
            else:
                # Stream ended without any answer text (LLM likely unavailable).
                error_msg = "⚠️ LLM 未生成回答(可能 LLM 服務未啟動)"
                history_with_answer = history_with_user.copy()
                history_with_answer.append({"role": "assistant", "content": error_msg})
                yield history_with_answer, status_msg
        else:
            # Retrieval-only path: show the formatted context as the reply.
            result = rag.query(
                query=message,
                top_k=int(top_k),
                use_llm=False,
                conversation_history=conversation_history
            )

            if not result.get("success"):
                error = result.get("error", "未知錯誤")
                error_msg = f"❌ 查詢失敗: {error}"
                history.append({"role": "assistant", "content": error_msg})
                yield history, error_msg
                return

            # Present the retrieved passages directly.
            formatted_context = result.get("formatted_context", "")
            answer = f"📄 檢索到的相關內容:\n\n{formatted_context}"

            # Status line with the method used and elapsed time, if reported.
            rag_method = result.get("rag_method", "basic")
            stats = result.get("stats", {})
            status_msg = f"✅ 查詢完成(方法: {rag_method.upper()})"
            if stats:
                total_time = stats.get("total_time", 0)
                if total_time > 0:
                    status_msg += f" | 耗時: {total_time:.2f}秒"

            history.append({"role": "assistant", "content": answer})
            yield history, status_msg

    except Exception as e:
        error_msg = f"❌ 查詢時發生錯誤: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        # Make sure the failed turn is still reflected in the transcript
        # (the user message may or may not have been appended before the raise).
        history = ensure_dict_format(history)
        if not any(msg.get("role") == "user" and msg.get("content") == message for msg in history):
            history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": error_msg})
        yield history, error_msg
|
| 563 |
+
|
| 564 |
+
def clear_chat():
    """Wipe the visible conversation only; the RAG index stays loaded."""
    return [], "對話已清除"
|
| 567 |
+
|
| 568 |
+
def clear_all():
    """Reset everything: tear down the RAG instance and restore all defaults.

    Returns one value per output component wired to ``clear_files_btn``, in
    order (see the click binding). The same empty list object is reused for
    the chatbot display and the chat-history state.
    """
    reset_private_rag_instance()
    empty_history = []
    return (
        None,               # file_upload: remove uploaded files
        False,              # use_semantic_chunking
        500,                # chunk_size_slider
        100,                # chunk_overlap_slider
        1.0,                # semantic_threshold_slider
        100,                # semantic_min_chunk_slider
        True,               # enable_adaptive_selection
        "basic",            # manual_rag_method
        "等待上傳文件...",    # process_status
        empty_history,      # chatbot (conversation display)
        empty_history,      # chat_history (state)
        # BUG FIX: original literal was mojibake ("等��查詢...");
        # restored to match query_status's initial value "等待查詢...".
        "等待查詢...",       # query_status
    )
|
| 586 |
+
|
| 587 |
+
# --- Event wiring ---
# "Process files" runs document ingestion and reports into both status boxes.
process_btn.click(
    fn=process_files,
    inputs=[
        file_upload,
        use_semantic_chunking,
        chunk_size_slider,
        chunk_overlap_slider,
        semantic_threshold_slider,
        semantic_min_chunk_slider
    ],
    outputs=[process_status, query_status]
)

# Show the manual RAG-method dropdown only when adaptive selection is off.
def toggle_manual_method(enable_adaptive):
    """Return a Gradio update toggling the manual-method dropdown's visibility."""
    return gr.update(visible=not enable_adaptive)

enable_adaptive_selection.change(
    fn=toggle_manual_method,
    inputs=[enable_adaptive_selection],
    outputs=[manual_rag_method]
)
|
| 610 |
+
|
| 611 |
+
# Submit a message (button click or Enter key).
def submit_message(message, history, top_k, use_llm, enable_adaptive, manual_method):
    """Gradio callback: run the streaming RAG query for one user message.

    Yields (chatbot, chat_history, msg, query_status) tuples so the UI updates
    incrementally; the input textbox is cleared ("") on every update.
    """
    if not message or not message.strip():
        # BUG FIX: this function contains `yield`, so it is a generator — the
        # original `return history, history, "", "等待查詢..."` became the
        # StopIteration value and Gradio never received an update for empty
        # input. Yield the unchanged state instead, then stop.
        history = ensure_dict_format(history)
        yield history, history, "", "等待查詢..."
        return
    # Clear the input box and forward each streamed update to the UI.
    for new_history, status in query_rag_stream(message, history, top_k, use_llm, enable_adaptive, manual_method):
        yield new_history, new_history, "", status
|
| 621 |
+
|
| 622 |
+
# Both the send button and Enter in the textbox trigger the streaming query.
submit_btn.click(
    fn=submit_message,
    inputs=[msg, chat_history, top_k_slider, use_llm_checkbox, enable_adaptive_selection, manual_rag_method],
    outputs=[chatbot, chat_history, msg, query_status]
)

msg.submit(
    fn=submit_message,
    inputs=[msg, chat_history, top_k_slider, use_llm_checkbox, enable_adaptive_selection, manual_rag_method],
    outputs=[chatbot, chat_history, msg, query_status]
)
|
| 634 |
+
|
| 635 |
+
# Clear-chat button also needs to refresh the chat_history state.
def clear_chat_with_state():
    """Produce fresh outputs for (chatbot, chat_history, query_status) after a chat reset."""
    cleared = []
    return cleared, cleared, "對話已清除"
|
| 640 |
+
|
| 641 |
+
# "Clear chat" resets only the transcript (the RAG index is kept).
clear_chat_btn.click(
    fn=clear_chat_with_state,
    outputs=[chatbot, chat_history, query_status]
)

# "Clear all" tears down the RAG instance and resets every control; outputs
# must stay in the same order as clear_all's return tuple.
clear_files_btn.click(
    fn=clear_all,
    outputs=[
        file_upload,
        use_semantic_chunking,
        chunk_size_slider,
        chunk_overlap_slider,
        semantic_threshold_slider,
        semantic_min_chunk_slider,
        enable_adaptive_selection,
        manual_rag_method,
        process_status,
        chatbot,       # refresh the chatbot display
        chat_history,  # refresh the stored state
        query_status
    ]
)
|
deep_agent_rag/utils/llm_utils.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
| 1 |
"""
|
| 2 |
LLM 工具函數
|
| 3 |
提供 LLM 實例的創建和管理
|
| 4 |
-
優先
|
| 5 |
"""
|
| 6 |
import warnings
|
| 7 |
from typing import Optional
|
| 8 |
from langchain_groq import ChatGroq
|
|
|
|
| 9 |
from ..models import MLXChatModel, load_mlx_model
|
| 10 |
from ..config import (
|
| 11 |
MLX_MAX_TOKENS,
|
|
@@ -14,7 +15,12 @@ from ..config import (
|
|
| 14 |
GROQ_MODEL,
|
| 15 |
GROQ_MAX_TOKENS,
|
| 16 |
GROQ_TEMPERATURE,
|
| 17 |
-
USE_GROQ_FIRST
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
)
|
| 19 |
|
| 20 |
# 全局變量:跟踪當前使用的 LLM 類型
|
|
@@ -29,30 +35,17 @@ def get_llm_type() -> str:
|
|
| 29 |
|
| 30 |
def is_using_local_llm() -> bool:
|
| 31 |
"""檢查是否正在使用本地 LLM"""
|
| 32 |
-
return _current_llm_type
|
| 33 |
|
| 34 |
|
| 35 |
def get_llm():
|
| 36 |
"""
|
| 37 |
獲取 LLM 實例
|
| 38 |
-
優先
|
| 39 |
"""
|
| 40 |
global _current_llm_type, _groq_quota_exceeded
|
| 41 |
|
| 42 |
-
#
|
| 43 |
-
if _groq_quota_exceeded:
|
| 44 |
-
if _current_llm_type != "mlx":
|
| 45 |
-
print("⚠️ 警告:Groq API 額度已用完,已切換到本地 MLX 模型 (Qwen2.5)")
|
| 46 |
-
_current_llm_type = "mlx"
|
| 47 |
-
model, tokenizer = load_mlx_model()
|
| 48 |
-
return MLXChatModel(
|
| 49 |
-
model=model,
|
| 50 |
-
tokenizer=tokenizer,
|
| 51 |
-
max_tokens=MLX_MAX_TOKENS,
|
| 52 |
-
temperature=MLX_TEMPERATURE
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
# 嘗試使用 Groq API
|
| 56 |
if USE_GROQ_FIRST and GROQ_API_KEY:
|
| 57 |
try:
|
| 58 |
groq_llm = ChatGroq(
|
|
@@ -61,48 +54,61 @@ def get_llm():
|
|
| 61 |
max_tokens=GROQ_MAX_TOKENS,
|
| 62 |
temperature=GROQ_TEMPERATURE
|
| 63 |
)
|
| 64 |
-
# 測試連接(通過一個簡單的調用來驗證)
|
| 65 |
-
# 注意:這裡不實際調用,只是創建實例
|
| 66 |
_current_llm_type = "groq"
|
| 67 |
print("✅ 使用 Groq API (優先)")
|
| 68 |
return groq_llm
|
| 69 |
except Exception as e:
|
| 70 |
-
# 如果創建失敗,
|
| 71 |
print(f"⚠️ Groq API 初始化失敗: {e}")
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
| 81 |
)
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
def handle_groq_error(error: Exception) -> Optional[MLXChatModel]:
|
| 97 |
"""
|
| 98 |
處理 Groq API 錯誤
|
| 99 |
-
如果是額度用完錯誤,切換到
|
| 100 |
|
| 101 |
Args:
|
| 102 |
error: 捕獲的異常
|
| 103 |
|
| 104 |
Returns:
|
| 105 |
-
如果切換到本地模型,返回 MLXChatModel;否則返回 None
|
| 106 |
"""
|
| 107 |
global _current_llm_type, _groq_quota_exceeded
|
| 108 |
|
|
@@ -121,10 +127,27 @@ def handle_groq_error(error: Exception) -> Optional[MLXChatModel]:
|
|
| 121 |
if any(indicator in error_str for indicator in quota_indicators):
|
| 122 |
if not _groq_quota_exceeded:
|
| 123 |
_groq_quota_exceeded = True
|
| 124 |
-
warning_msg = "⚠️ 警告:Groq API 額度已用完
|
| 125 |
print(warning_msg)
|
| 126 |
warnings.warn(warning_msg, UserWarning)
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
_current_llm_type = "mlx"
|
| 129 |
model, tokenizer = load_mlx_model()
|
| 130 |
return MLXChatModel(
|
|
|
|
| 1 |
"""
|
| 2 |
LLM 工具函數
|
| 3 |
提供 LLM 實例的創建和管理
|
| 4 |
+
優先順序:Groq API > Ollama > MLX 模型
|
| 5 |
"""
|
| 6 |
import warnings
|
| 7 |
from typing import Optional
|
| 8 |
from langchain_groq import ChatGroq
|
| 9 |
+
from langchain_ollama import ChatOllama
|
| 10 |
from ..models import MLXChatModel, load_mlx_model
|
| 11 |
from ..config import (
|
| 12 |
MLX_MAX_TOKENS,
|
|
|
|
| 15 |
GROQ_MODEL,
|
| 16 |
GROQ_MAX_TOKENS,
|
| 17 |
GROQ_TEMPERATURE,
|
| 18 |
+
USE_GROQ_FIRST,
|
| 19 |
+
OLLAMA_BASE_URL,
|
| 20 |
+
OLLAMA_MODEL,
|
| 21 |
+
OLLAMA_MAX_TOKENS,
|
| 22 |
+
OLLAMA_TEMPERATURE,
|
| 23 |
+
USE_OLLAMA,
|
| 24 |
)
|
| 25 |
|
| 26 |
# 全局變量:跟踪當前使用的 LLM 類型
|
|
|
|
| 35 |
|
| 36 |
def is_using_local_llm() -> bool:
|
| 37 |
"""檢查是否正在使用本地 LLM"""
|
| 38 |
+
return _current_llm_type in ["mlx", "ollama"] or _groq_quota_exceeded
|
| 39 |
|
| 40 |
|
| 41 |
def get_llm():
|
| 42 |
"""
|
| 43 |
獲取 LLM 實例
|
| 44 |
+
優先順序:Groq API > Ollama > MLX 模型
|
| 45 |
"""
|
| 46 |
global _current_llm_type, _groq_quota_exceeded
|
| 47 |
|
| 48 |
+
# 優先順序 1: Groq API
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
if USE_GROQ_FIRST and GROQ_API_KEY:
|
| 50 |
try:
|
| 51 |
groq_llm = ChatGroq(
|
|
|
|
| 54 |
max_tokens=GROQ_MAX_TOKENS,
|
| 55 |
temperature=GROQ_TEMPERATURE
|
| 56 |
)
|
|
|
|
|
|
|
| 57 |
_current_llm_type = "groq"
|
| 58 |
print("✅ 使用 Groq API (優先)")
|
| 59 |
return groq_llm
|
| 60 |
except Exception as e:
|
| 61 |
+
# 如果創建失敗,繼續嘗試其他選項
|
| 62 |
print(f"⚠️ Groq API 初始化失敗: {e}")
|
| 63 |
+
# 不立即設置 _groq_quota_exceeded,先嘗試 Ollama
|
| 64 |
+
|
| 65 |
+
# 優先順序 2: Ollama (Llama 3.2 或其他模型)
|
| 66 |
+
if USE_OLLAMA:
|
| 67 |
+
try:
|
| 68 |
+
ollama_llm = ChatOllama(
|
| 69 |
+
base_url=OLLAMA_BASE_URL,
|
| 70 |
+
model=OLLAMA_MODEL,
|
| 71 |
+
num_predict=OLLAMA_MAX_TOKENS,
|
| 72 |
+
temperature=OLLAMA_TEMPERATURE,
|
| 73 |
)
|
| 74 |
+
_current_llm_type = "ollama"
|
| 75 |
+
print(f"✅ 使用 Ollama 模型 ({OLLAMA_MODEL})")
|
| 76 |
+
return ollama_llm
|
| 77 |
+
except Exception as e:
|
| 78 |
+
print(f"⚠️ Ollama 初始化失敗: {e}")
|
| 79 |
+
print(" 請確保 Ollama 服務正在運行: ollama serve")
|
| 80 |
+
print(" 或檢查模型是否已下載: ollama pull " + OLLAMA_MODEL)
|
| 81 |
+
|
| 82 |
+
# 優先順序 3: MLX 模型(備援)
|
| 83 |
+
# 如果 Groq 額度用完,記錄狀態
|
| 84 |
+
if _groq_quota_exceeded and _current_llm_type != "mlx":
|
| 85 |
+
print("⚠️ 警告:Groq API 額度已用完,已切換到本地 MLX 模型 (Qwen2.5)")
|
| 86 |
+
elif _current_llm_type != "mlx":
|
| 87 |
+
if not GROQ_API_KEY and not USE_OLLAMA:
|
| 88 |
+
print("ℹ️ 未配置 Groq API 或 Ollama,使用本地 MLX 模型")
|
| 89 |
+
elif not USE_OLLAMA:
|
| 90 |
+
print("ℹ️ Ollama 未啟用,使用本地 MLX 模型作為備援")
|
| 91 |
+
|
| 92 |
+
_current_llm_type = "mlx"
|
| 93 |
+
model, tokenizer = load_mlx_model()
|
| 94 |
+
return MLXChatModel(
|
| 95 |
+
model=model,
|
| 96 |
+
tokenizer=tokenizer,
|
| 97 |
+
max_tokens=MLX_MAX_TOKENS,
|
| 98 |
+
temperature=MLX_TEMPERATURE
|
| 99 |
+
)
|
| 100 |
|
| 101 |
|
| 102 |
def handle_groq_error(error: Exception) -> Optional[MLXChatModel]:
|
| 103 |
"""
|
| 104 |
處理 Groq API 錯誤
|
| 105 |
+
如果是額度用完錯誤,先嘗試切換到 Ollama,否則切換到 MLX 模型
|
| 106 |
|
| 107 |
Args:
|
| 108 |
error: 捕獲的異常
|
| 109 |
|
| 110 |
Returns:
|
| 111 |
+
如果切換到本地模型,返回 ChatOllama 或 MLXChatModel;否則返回 None
|
| 112 |
"""
|
| 113 |
global _current_llm_type, _groq_quota_exceeded
|
| 114 |
|
|
|
|
| 127 |
if any(indicator in error_str for indicator in quota_indicators):
|
| 128 |
if not _groq_quota_exceeded:
|
| 129 |
_groq_quota_exceeded = True
|
| 130 |
+
warning_msg = "⚠️ 警告:Groq API 額度已用完"
|
| 131 |
print(warning_msg)
|
| 132 |
warnings.warn(warning_msg, UserWarning)
|
| 133 |
|
| 134 |
+
# 先嘗試使用 Ollama
|
| 135 |
+
if USE_OLLAMA:
|
| 136 |
+
try:
|
| 137 |
+
ollama_llm = ChatOllama(
|
| 138 |
+
base_url=OLLAMA_BASE_URL,
|
| 139 |
+
model=OLLAMA_MODEL,
|
| 140 |
+
num_predict=OLLAMA_MAX_TOKENS,
|
| 141 |
+
temperature=OLLAMA_TEMPERATURE,
|
| 142 |
+
)
|
| 143 |
+
_current_llm_type = "ollama"
|
| 144 |
+
print(f"✅ 已切換到 Ollama 模型 ({OLLAMA_MODEL})")
|
| 145 |
+
return ollama_llm
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"⚠️ Ollama 切換失敗: {e}")
|
| 148 |
+
print(" 回退到 MLX 模型")
|
| 149 |
+
|
| 150 |
+
# 回退到 MLX 模型
|
| 151 |
_current_llm_type = "mlx"
|
| 152 |
model, tokenizer = load_mlx_model()
|
| 153 |
return MLXChatModel(
|
pyproject.toml
CHANGED
|
@@ -17,6 +17,7 @@ dependencies = [
|
|
| 17 |
"yfinance>=0.2.66",
|
| 18 |
"langgraph>=1.0.4",
|
| 19 |
"langchain-groq>=1.1.0",
|
|
|
|
| 20 |
"grandalf>=0.8",
|
| 21 |
"langserve[all]>=0.3.3",
|
| 22 |
"fastapi>=0.124.2",
|
|
@@ -36,5 +37,15 @@ dependencies = [
|
|
| 36 |
"google-auth-httplib2>=0.3.0",
|
| 37 |
"google-auth-oauthlib>=1.2.3",
|
| 38 |
"googlemaps>=4.10.0",
|
|
|
|
| 39 |
"parlant @ git+https://github.com/emcie-co/parlant@develop",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
]
|
|
|
|
| 17 |
"yfinance>=0.2.66",
|
| 18 |
"langgraph>=1.0.4",
|
| 19 |
"langchain-groq>=1.1.0",
|
| 20 |
+
"langchain-ollama>=0.1.0",
|
| 21 |
"grandalf>=0.8",
|
| 22 |
"langserve[all]>=0.3.3",
|
| 23 |
"fastapi>=0.124.2",
|
|
|
|
| 37 |
"google-auth-httplib2>=0.3.0",
|
| 38 |
"google-auth-oauthlib>=1.2.3",
|
| 39 |
"googlemaps>=4.10.0",
|
| 40 |
+
<<<<<<< HEAD
|
| 41 |
"parlant @ git+https://github.com/emcie-co/parlant@develop",
|
| 42 |
+
=======
|
| 43 |
+
# Learn_RAG 項目依賴(用於私有文件 RAG 功能)
|
| 44 |
+
"arxiv>=2.3.1",
|
| 45 |
+
"langchain-text-splitters>=0.0.1",
|
| 46 |
+
"rank-bm25>=0.2.2",
|
| 47 |
+
"chromadb>=0.4.22",
|
| 48 |
+
"docx2txt>=0.8",
|
| 49 |
+
"langchain-experimental>=0.0.50",
|
| 50 |
+
>>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
|
| 51 |
]
|
src/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
RAG system module package.

Aggregates the document processor, the retriever implementations, prompt
formatting, the Ollama LLM wrapper, and the various RAG strategy classes
into one public namespace (see ``__all__``).
"""
from .document_processor import DocumentProcessor
from .retrievers import (
    BaseRetriever,
    BM25Retriever,
    VectorRetriever,
    HybridSearch,
    Reranker,
    RAGPipeline,
)
from .prompt_formatter import PromptFormatter
from .llm_integration import OllamaLLM
from .subquery_rag import SubQueryDecompositionRAG
from .hyde_rag import HyDERAG
from .hybrid_subquery_hyde_rag import HybridSubqueryHyDERAG
from .step_back_rag import StepBackRAG
from .triple_hybrid_rag import TripleHybridRAG

# Public API of the package.
__all__ = [
    "DocumentProcessor",
    "BaseRetriever",
    "BM25Retriever",
    "VectorRetriever",
    "HybridSearch",
    "Reranker",
    "RAGPipeline",
    "PromptFormatter",
    "OllamaLLM",
    "SubQueryDecompositionRAG",
    "HyDERAG",
    "HybridSubqueryHyDERAG",
    "StepBackRAG",
    "TripleHybridRAG",
]
|
| 37 |
+
|
src/document_processor.py
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
文檔處理模組:載入 arXiv 論文並進行文字分割
|
| 3 |
+
支援本地檔案:PDF, DOCX, TXT
|
| 4 |
+
支援兩種分塊策略:
|
| 5 |
+
1. 字符分塊(預設):基於固定字符數的分塊,速度快
|
| 6 |
+
2. 語義分塊(可選):基於語義相似度的分塊,能保持語義完整性
|
| 7 |
+
"""
|
| 8 |
+
from typing import List, Dict, Optional, Any
|
| 9 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import os
|
| 12 |
+
import arxiv
|
| 13 |
+
import re
|
| 14 |
+
|
| 15 |
+
# 嘗試導入語義分塊器(需要 langchain-experimental)
|
| 16 |
+
try:
|
| 17 |
+
from langchain_experimental.text_splitter import SemanticChunker
|
| 18 |
+
SEMANTIC_CHUNKER_AVAILABLE = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
SEMANTIC_CHUNKER_AVAILABLE = False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DocumentProcessor:
|
| 24 |
+
"""
|
| 25 |
+
處理 arXiv 論文文檔,進行分割和準備
|
| 26 |
+
|
| 27 |
+
支援兩種分塊模式:
|
| 28 |
+
- 字符分塊(預設):快速、穩定,適合大多數場景
|
| 29 |
+
- 語義分塊(可選):更智能,能保持語義完整性,但需要額外依賴和計算時間
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
chunk_size: int = 1000,
|
| 35 |
+
chunk_overlap: int = 200,
|
| 36 |
+
embeddings: Optional[Any] = None, # 可選:用於語義分塊的 embedding 模型
|
| 37 |
+
use_semantic_chunking: bool = False, # 是否使用語義分塊
|
| 38 |
+
breakpoint_threshold_amount: float = 1.5, # 語義分塊敏感度(標準差倍數)
|
| 39 |
+
min_chunk_size: int = 100 # 語義分塊的最小 chunk 大小(字符數)
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
初始化文檔處理器
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
chunk_size: 每個 chunk 的大小(字符數),僅用於字符分塊模式
|
| 46 |
+
chunk_overlap: chunk 之間的重疊大小(字符數),僅用於字符分塊模式
|
| 47 |
+
embeddings: 用於計算語義距離的 embedding 模型物件(可選)
|
| 48 |
+
當 use_semantic_chunking=True 時必須提供
|
| 49 |
+
use_semantic_chunking: 是否使用語義分塊
|
| 50 |
+
True: 使用語義分塊(需要提供 embeddings)
|
| 51 |
+
False: 使用字符分塊(預設)
|
| 52 |
+
breakpoint_threshold_amount: 語義分塊的敏感度參數
|
| 53 |
+
數值越大,分塊越少(chunks 越大)
|
| 54 |
+
數值越小,分塊越多(chunks 越小)
|
| 55 |
+
建議範圍:1.0 - 2.0,預設 1.5
|
| 56 |
+
min_chunk_size: 語義分塊的最小 chunk 大小(字符數)
|
| 57 |
+
小於此大小的 chunks 會被合併到相鄰的 chunks
|
| 58 |
+
預設 100 字符
|
| 59 |
+
"""
|
| 60 |
+
self.embeddings = embeddings
|
| 61 |
+
self.use_semantic_chunking = use_semantic_chunking
|
| 62 |
+
self.min_chunk_size = min_chunk_size
|
| 63 |
+
|
| 64 |
+
# 如果要求使用語義分塊
|
| 65 |
+
if use_semantic_chunking:
|
| 66 |
+
# 檢查是否安裝了必要的依賴
|
| 67 |
+
if not SEMANTIC_CHUNKER_AVAILABLE:
|
| 68 |
+
raise ImportError(
|
| 69 |
+
"使用語義分塊需要安裝 langchain-experimental 套件。\n"
|
| 70 |
+
"請執行: pip install langchain-experimental\n"
|
| 71 |
+
"或使用字符分塊模式(use_semantic_chunking=False)"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# 檢查是否提供了 embeddings
|
| 75 |
+
if embeddings is None:
|
| 76 |
+
raise ValueError(
|
| 77 |
+
"使用語義分塊時必須提供 embeddings 參數。\n"
|
| 78 |
+
"範例:\n"
|
| 79 |
+
" from langchain_community.embeddings import HuggingFaceEmbeddings\n"
|
| 80 |
+
" embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')\n"
|
| 81 |
+
" processor = DocumentProcessor(embeddings=embeddings, use_semantic_chunking=True)"
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# 初始化語義分塊器
|
| 85 |
+
# 使用「標準差」策略:當相鄰句子之間的語義距離超過平均距離的標準差倍數時,進行切分
|
| 86 |
+
self.text_splitter = SemanticChunker(
|
| 87 |
+
embeddings,
|
| 88 |
+
breakpoint_threshold_type="standard_deviation",
|
| 89 |
+
breakpoint_threshold_amount=breakpoint_threshold_amount
|
| 90 |
+
)
|
| 91 |
+
print(f"✓ 使用語義分塊模式(敏感度: {breakpoint_threshold_amount},最小 chunk 大小: {min_chunk_size} 字符)")
|
| 92 |
+
else:
|
| 93 |
+
# 使用傳統的字符分塊(預設模式)
|
| 94 |
+
self.text_splitter = RecursiveCharacterTextSplitter(
|
| 95 |
+
chunk_size=chunk_size,
|
| 96 |
+
chunk_overlap=chunk_overlap,
|
| 97 |
+
length_function=len,
|
| 98 |
+
)
|
| 99 |
+
print(f"✓ 使用字符分塊模式(大小: {chunk_size} 字符,重疊: {chunk_overlap} 字符)")
|
| 100 |
+
|
| 101 |
+
def _post_process_chunks(self, chunks: List[str]) -> List[str]:
|
| 102 |
+
"""
|
| 103 |
+
後處理 chunks:過濾和合併太小的 chunks
|
| 104 |
+
|
| 105 |
+
語義分塊可能會產生一些非��小的 chunks(例如只有幾個單詞),
|
| 106 |
+
這些小 chunks 可能不包含足夠的上下文資訊。此方法會:
|
| 107 |
+
1. 將小於 min_chunk_size 的 chunks 合併到相鄰的 chunks
|
| 108 |
+
2. 確保最終的 chunks 都有足夠的大小
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
chunks: 原始 chunks 列表(從分塊器產生的)
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
處理後的 chunks 列表(過濾和合併後的)
|
| 115 |
+
"""
|
| 116 |
+
# 如果使用字符分塊,不需要後處理(因為已經有固定大小)
|
| 117 |
+
if not self.use_semantic_chunking:
|
| 118 |
+
return chunks
|
| 119 |
+
|
| 120 |
+
# 如果沒有 chunks,直接返回
|
| 121 |
+
if not chunks:
|
| 122 |
+
return chunks
|
| 123 |
+
|
| 124 |
+
processed = []
|
| 125 |
+
current_small_chunk = "" # 累積的小 chunk
|
| 126 |
+
|
| 127 |
+
for chunk in chunks:
|
| 128 |
+
chunk_stripped = chunk.strip()
|
| 129 |
+
chunk_length = len(chunk_stripped)
|
| 130 |
+
|
| 131 |
+
# 如果當前 chunk 太小,嘗試與下一個合併
|
| 132 |
+
if chunk_length < self.min_chunk_size:
|
| 133 |
+
# 累積到臨時變數中
|
| 134 |
+
if current_small_chunk:
|
| 135 |
+
current_small_chunk += "\n\n" + chunk
|
| 136 |
+
else:
|
| 137 |
+
current_small_chunk = chunk
|
| 138 |
+
else:
|
| 139 |
+
# 當前 chunk 足夠大
|
| 140 |
+
# 如果有累積的小 chunk,先處理它
|
| 141 |
+
if current_small_chunk:
|
| 142 |
+
current_small_chunk_stripped = current_small_chunk.strip()
|
| 143 |
+
if len(current_small_chunk_stripped) >= self.min_chunk_size:
|
| 144 |
+
# 累積後足夠大,作為獨立 chunk
|
| 145 |
+
processed.append(current_small_chunk)
|
| 146 |
+
else:
|
| 147 |
+
# 累積後還是太小,合併到上一個 chunk(如果存在)
|
| 148 |
+
if processed:
|
| 149 |
+
processed[-1] += "\n\n" + current_small_chunk
|
| 150 |
+
else:
|
| 151 |
+
# 如果沒有上一個 chunk,還是要保留
|
| 152 |
+
processed.append(current_small_chunk)
|
| 153 |
+
current_small_chunk = ""
|
| 154 |
+
|
| 155 |
+
# 添加當前足夠大的 chunk
|
| 156 |
+
processed.append(chunk)
|
| 157 |
+
|
| 158 |
+
# 處理最後的累積小 chunk
|
| 159 |
+
if current_small_chunk:
|
| 160 |
+
current_small_chunk_stripped = current_small_chunk.strip()
|
| 161 |
+
if len(current_small_chunk_stripped) >= self.min_chunk_size:
|
| 162 |
+
# 足夠大,作為獨立 chunk
|
| 163 |
+
processed.append(current_small_chunk)
|
| 164 |
+
elif processed:
|
| 165 |
+
# 太小,合併到最後一個 chunk
|
| 166 |
+
processed[-1] += "\n\n" + current_small_chunk
|
| 167 |
+
else:
|
| 168 |
+
# 如果沒有其他 chunks,還是要保留
|
| 169 |
+
processed.append(current_small_chunk)
|
| 170 |
+
|
| 171 |
+
return processed
|
| 172 |
+
|
| 173 |
+
def fetch_papers(self, query: str, max_results: int = 10) -> List[Dict]:
|
| 174 |
+
"""
|
| 175 |
+
從 arXiv 獲取論文
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
query: 搜尋查詢(例如 "cat:cs.AI")
|
| 179 |
+
max_results: 最大結果數量
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
論文列表,每個論文包含標題、摘要等資訊
|
| 183 |
+
"""
|
| 184 |
+
search = arxiv.Search(
|
| 185 |
+
query=query,
|
| 186 |
+
max_results=max_results,
|
| 187 |
+
sort_by=arxiv.SortCriterion.SubmittedDate
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
papers = []
|
| 191 |
+
for paper in search.results():
|
| 192 |
+
papers.append({
|
| 193 |
+
"title": paper.title,
|
| 194 |
+
"authors": [author.name for author in paper.authors],
|
| 195 |
+
"summary": paper.summary,
|
| 196 |
+
"published": str(paper.published),
|
| 197 |
+
"arxiv_id": paper.entry_id.split('/')[-1],
|
| 198 |
+
"arxiv_url": paper.entry_id,
|
| 199 |
+
"pdf_url": paper.pdf_url,
|
| 200 |
+
"categories": paper.categories,
|
| 201 |
+
})
|
| 202 |
+
|
| 203 |
+
return papers
|
| 204 |
+
|
| 205 |
+
def process_documents(self, papers: List[Dict]) -> List[Dict]:
|
| 206 |
+
"""
|
| 207 |
+
處理論文,將每篇論文分割成 chunks
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
papers: 論文列表
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
處理後的文檔 chunks,每個 chunk 包含內容和元數據
|
| 214 |
+
"""
|
| 215 |
+
documents = []
|
| 216 |
+
|
| 217 |
+
for paper in papers:
|
| 218 |
+
# 組合論文的完整文字(標題 + 摘要)
|
| 219 |
+
# 保留換行符號 \n\n 作為語義斷點的結構參考
|
| 220 |
+
full_text = f"Title: {paper['title']}\n\nAbstract: {paper['summary']}"
|
| 221 |
+
|
| 222 |
+
# 分割文字(根據選擇的模式:字符分塊或語義分塊)
|
| 223 |
+
chunks = self.text_splitter.split_text(full_text)
|
| 224 |
+
|
| 225 |
+
# 後處理:過濾和合併太小的 chunks(僅語義分塊模式)
|
| 226 |
+
chunks = self._post_process_chunks(chunks)
|
| 227 |
+
|
| 228 |
+
# 為每個 chunk 創建文檔物件
|
| 229 |
+
for i, chunk in enumerate(chunks):
|
| 230 |
+
doc = {
|
| 231 |
+
"content": chunk,
|
| 232 |
+
"metadata": {
|
| 233 |
+
"title": paper['title'],
|
| 234 |
+
"arxiv_id": paper['arxiv_id'],
|
| 235 |
+
"arxiv_url": paper['arxiv_url'],
|
| 236 |
+
"pdf_url": paper['pdf_url'],
|
| 237 |
+
"authors": paper['authors'],
|
| 238 |
+
"published": paper['published'],
|
| 239 |
+
"categories": paper['categories'],
|
| 240 |
+
"chunk_index": i,
|
| 241 |
+
"total_chunks": len(chunks),
|
| 242 |
+
"chunking_method": "semantic" if self.use_semantic_chunking else "character"
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
documents.append(doc)
|
| 246 |
+
|
| 247 |
+
return documents
|
| 248 |
+
|
| 249 |
+
def get_texts_and_metadatas(self, documents: List[Dict]):
|
| 250 |
+
"""
|
| 251 |
+
從文檔列表中提取文字和元數據
|
| 252 |
+
|
| 253 |
+
Args:
|
| 254 |
+
documents: 文檔列表
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
(texts, metadatas) 元組
|
| 258 |
+
"""
|
| 259 |
+
texts = [doc["content"] for doc in documents]
|
| 260 |
+
metadatas = [doc["metadata"] for doc in documents]
|
| 261 |
+
return texts, metadatas
|
| 262 |
+
|
| 263 |
+
@staticmethod
|
| 264 |
+
def clean_extracted_text(text: str) -> str:
|
| 265 |
+
"""
|
| 266 |
+
清理從 PDF/DOCX 提取的文本,移除多餘的空格和修復字符換行問題
|
| 267 |
+
|
| 268 |
+
某些 PDF 提取工具會在每個字符之間插入空格或換行,特別是中文文本。
|
| 269 |
+
此方法會:
|
| 270 |
+
1. 修復「每個字符一行」的問題(將單字符行合併)
|
| 271 |
+
2. 移除中文字符之間的多餘空格
|
| 272 |
+
3. 保留英文單詞之間的空格
|
| 273 |
+
4. 保留標點符號周圍的適當空格
|
| 274 |
+
5. 保留真正的段落分隔
|
| 275 |
+
|
| 276 |
+
Args:
|
| 277 |
+
text: 原始提取的文本
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
清理後的文本
|
| 281 |
+
"""
|
| 282 |
+
if not text:
|
| 283 |
+
return text
|
| 284 |
+
|
| 285 |
+
# 步驟 0: 修復「每個字符一行」的問題
|
| 286 |
+
# 檢測模式:每行只有一個字符(可能是中文字符、標點、或單個字母/數字)
|
| 287 |
+
# 將這些單字符行合併成連續文本
|
| 288 |
+
lines = text.split('\n')
|
| 289 |
+
merged_lines = []
|
| 290 |
+
i = 0
|
| 291 |
+
|
| 292 |
+
def is_single_char_line(line: str) -> bool:
|
| 293 |
+
"""
|
| 294 |
+
判斷是否為單字符行
|
| 295 |
+
考慮:去除空格後長度 <= 3(可能是單字符+標點,或單字符+空格)
|
| 296 |
+
"""
|
| 297 |
+
stripped = line.strip()
|
| 298 |
+
if not stripped:
|
| 299 |
+
return False # 空行不算
|
| 300 |
+
# 如果去除空格後長度 <= 3,且主要是中文字符、標點或單個字母/數字
|
| 301 |
+
if len(stripped) <= 3:
|
| 302 |
+
# 檢查是否主要是單個字符(可能帶標點或空格)
|
| 303 |
+
# 移除所有空格後,如果長度 <= 2,認為是單字符行
|
| 304 |
+
no_space = stripped.replace(' ', '')
|
| 305 |
+
if len(no_space) <= 2:
|
| 306 |
+
return True
|
| 307 |
+
return False
|
| 308 |
+
|
| 309 |
+
while i < len(lines):
|
| 310 |
+
line = lines[i]
|
| 311 |
+
stripped_line = line.strip()
|
| 312 |
+
|
| 313 |
+
# 如果當前行是單字符行
|
| 314 |
+
if is_single_char_line(line):
|
| 315 |
+
# 收集連續的單字符行(包括空行,因為空行可能是分隔符)
|
| 316 |
+
merged_chars = []
|
| 317 |
+
j = i
|
| 318 |
+
consecutive_single_chars = 0
|
| 319 |
+
|
| 320 |
+
while j < len(lines):
|
| 321 |
+
current_line = lines[j]
|
| 322 |
+
current_stripped = current_line.strip()
|
| 323 |
+
|
| 324 |
+
if is_single_char_line(current_line):
|
| 325 |
+
# 是單字符行,收集字符(去除空格)
|
| 326 |
+
char = current_stripped.replace(' ', '')
|
| 327 |
+
if char:
|
| 328 |
+
merged_chars.append(char)
|
| 329 |
+
consecutive_single_chars += 1
|
| 330 |
+
j += 1
|
| 331 |
+
elif not current_stripped:
|
| 332 |
+
# 空行:如果前面有單字符,且後面可能還有單字符,跳過空行
|
| 333 |
+
# 檢查下一行是否也是單字符
|
| 334 |
+
if j + 1 < len(lines) and is_single_char_line(lines[j + 1]):
|
| 335 |
+
# 空行後面還有單字符,跳過空行繼續收集
|
| 336 |
+
j += 1
|
| 337 |
+
else:
|
| 338 |
+
# 空行後面沒有單字符了,停止收集
|
| 339 |
+
break
|
| 340 |
+
else:
|
| 341 |
+
# 遇到正常行,停止收集
|
| 342 |
+
break
|
| 343 |
+
|
| 344 |
+
# 如果收集到多個單字符,合併它們
|
| 345 |
+
if len(merged_chars) > 1:
|
| 346 |
+
merged_text = ''.join(merged_chars)
|
| 347 |
+
merged_lines.append(merged_text)
|
| 348 |
+
i = j
|
| 349 |
+
continue
|
| 350 |
+
elif len(merged_chars) == 1 and consecutive_single_chars > 1:
|
| 351 |
+
# 只有一個字符但有多行(可能是空格導致的),也合併
|
| 352 |
+
merged_text = ''.join(merged_chars)
|
| 353 |
+
merged_lines.append(merged_text)
|
| 354 |
+
i = j
|
| 355 |
+
continue
|
| 356 |
+
else:
|
| 357 |
+
# 只有一個單字符,且確實只有一行,保留原樣
|
| 358 |
+
if merged_chars:
|
| 359 |
+
merged_lines.append(merged_chars[0])
|
| 360 |
+
i = j
|
| 361 |
+
continue
|
| 362 |
+
else:
|
| 363 |
+
# 正常行,直接添加
|
| 364 |
+
if stripped_line: # 非空行
|
| 365 |
+
merged_lines.append(stripped_line)
|
| 366 |
+
i += 1
|
| 367 |
+
|
| 368 |
+
# 重新組合文本
|
| 369 |
+
text = '\n'.join(merged_lines)
|
| 370 |
+
|
| 371 |
+
# 步驟 0.5: 再次處理可能的殘留問題
|
| 372 |
+
# 如果還有單字符行(可能是第一次處理遺漏的),再次處理
|
| 373 |
+
lines = text.split('\n')
|
| 374 |
+
final_lines = []
|
| 375 |
+
i = 0
|
| 376 |
+
while i < len(lines):
|
| 377 |
+
line = lines[i].strip()
|
| 378 |
+
if is_single_char_line(line):
|
| 379 |
+
# 再次收集連續的單字符行
|
| 380 |
+
merged_chars = []
|
| 381 |
+
j = i
|
| 382 |
+
while j < len(lines) and is_single_char_line(lines[j]):
|
| 383 |
+
char = lines[j].strip().replace(' ', '')
|
| 384 |
+
if char:
|
| 385 |
+
merged_chars.append(char)
|
| 386 |
+
j += 1
|
| 387 |
+
|
| 388 |
+
if len(merged_chars) > 1:
|
| 389 |
+
final_lines.append(''.join(merged_chars))
|
| 390 |
+
i = j
|
| 391 |
+
else:
|
| 392 |
+
if merged_chars:
|
| 393 |
+
final_lines.append(merged_chars[0])
|
| 394 |
+
i = j
|
| 395 |
+
else:
|
| 396 |
+
if line:
|
| 397 |
+
final_lines.append(line)
|
| 398 |
+
i += 1
|
| 399 |
+
|
| 400 |
+
text = '\n'.join(final_lines)
|
| 401 |
+
|
| 402 |
+
# 1. 移除中文字符之間的空格
|
| 403 |
+
# 匹配模式:中文字符 + 空格 + 中文字符
|
| 404 |
+
chinese_char_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
|
| 405 |
+
text = re.sub(chinese_char_pattern, r'\1\2', text)
|
| 406 |
+
|
| 407 |
+
# 2. 移除中文和標點符號之間的多餘空格
|
| 408 |
+
# 中文 + 空格 + 標點符號
|
| 409 |
+
chinese_punct_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([,。、;:!?""''()【】《》])'
|
| 410 |
+
text = re.sub(chinese_punct_pattern, r'\1\2', text)
|
| 411 |
+
|
| 412 |
+
# 標點符號 + 空格 + 中文
|
| 413 |
+
# 使用 re.escape 來正確處理標點符號,避免轉義序列警告
|
| 414 |
+
punct_chars = ',。、;:!?""''()【】《》'
|
| 415 |
+
punct_chinese_pattern = f'([{re.escape(punct_chars)}])\\s+([\\u4e00-\\u9fff\\u3400-\\u4dbf\\uf900-\\ufaff])'
|
| 416 |
+
text = re.sub(punct_chinese_pattern, r'\1\2', text)
|
| 417 |
+
|
| 418 |
+
# 3. 移除數字和中文之間的多餘空格(例如:"500 公里" -> "500公里")
|
| 419 |
+
number_chinese_pattern = r'(\d+)\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
|
| 420 |
+
text = re.sub(number_chinese_pattern, r'\1\2', text)
|
| 421 |
+
chinese_number_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+(\d+)'
|
| 422 |
+
text = re.sub(chinese_number_pattern, r'\1\2', text)
|
| 423 |
+
|
| 424 |
+
# 4. 移除英文單詞內部的多餘空格(例如:"Nebula-X 跨次 元量" -> "Nebula-X 跨次元量")
|
| 425 |
+
# 但保留英文單詞之間的空格
|
| 426 |
+
# 匹配:非空格字符 + 空格 + 非空格字符(如果其中一個是中文,則移除空格)
|
| 427 |
+
mixed_space_pattern = r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])'
|
| 428 |
+
text = re.sub(mixed_space_pattern, r'\1\2', text)
|
| 429 |
+
|
| 430 |
+
# 5. 移除多個連續空格(保留單個空格,用於英文單詞之間)
|
| 431 |
+
text = re.sub(r' +', ' ', text)
|
| 432 |
+
|
| 433 |
+
# 6. 清理行首行尾的空格(但保留換行符)
|
| 434 |
+
lines = text.split('\n')
|
| 435 |
+
cleaned_lines = [line.strip() for line in lines]
|
| 436 |
+
text = '\n'.join(cleaned_lines)
|
| 437 |
+
|
| 438 |
+
# 7. 移除多個連續的換行符(保留最多兩個,用於段落分隔)
|
| 439 |
+
text = re.sub(r'\n{3,}', '\n\n', text)
|
| 440 |
+
|
| 441 |
+
# 8. 修復可能的殘留問題:移除中文字符之間殘留的空格
|
| 442 |
+
# 再次檢查並移除中文字符之間的空格(處理可能遺漏的情況)
|
| 443 |
+
text = re.sub(r'([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])\s+([\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff])', r'\1\2', text)
|
| 444 |
+
|
| 445 |
+
return text
|
| 446 |
+
|
| 447 |
+
def load_from_file(self, file_path: str) -> Dict:
|
| 448 |
+
"""
|
| 449 |
+
從本地檔案載入文檔(支援 PDF, DOCX, TXT 等)
|
| 450 |
+
|
| 451 |
+
Args:
|
| 452 |
+
file_path: 檔案路徑
|
| 453 |
+
|
| 454 |
+
Returns:
|
| 455 |
+
文檔字典,包含內容和元數據
|
| 456 |
+
"""
|
| 457 |
+
file_path = Path(file_path)
|
| 458 |
+
|
| 459 |
+
if not file_path.exists():
|
| 460 |
+
raise FileNotFoundError(f"檔案不存在: {file_path}")
|
| 461 |
+
|
| 462 |
+
file_ext = file_path.suffix.lower()
|
| 463 |
+
file_name = file_path.stem
|
| 464 |
+
file_size = os.path.getsize(file_path)
|
| 465 |
+
|
| 466 |
+
# 根據檔案類型選擇不同的加載器
|
| 467 |
+
if file_ext == '.pdf':
|
| 468 |
+
try:
|
| 469 |
+
from langchain_community.document_loaders import PyPDFLoader
|
| 470 |
+
loader = PyPDFLoader(str(file_path))
|
| 471 |
+
pages = loader.load()
|
| 472 |
+
# 合併所有頁面
|
| 473 |
+
full_text = "\n\n".join([page.page_content for page in pages])
|
| 474 |
+
# 清理提取的文本(移除多餘空格)
|
| 475 |
+
full_text = self.clean_extracted_text(full_text)
|
| 476 |
+
except ImportError:
|
| 477 |
+
raise ImportError(
|
| 478 |
+
"需要安裝 pypdf 來處理 PDF 檔案: pip install pypdf"
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
elif file_ext in ['.docx', '.doc']:
|
| 482 |
+
try:
|
| 483 |
+
from langchain_community.document_loaders import Docx2txtLoader
|
| 484 |
+
loader = Docx2txtLoader(str(file_path))
|
| 485 |
+
pages = loader.load()
|
| 486 |
+
full_text = "\n\n".join([page.page_content for page in pages])
|
| 487 |
+
# 清理提取的文本(移除多餘空格)
|
| 488 |
+
full_text = self.clean_extracted_text(full_text)
|
| 489 |
+
except ImportError:
|
| 490 |
+
raise ImportError(
|
| 491 |
+
"需要安裝 docx2txt 來處理 DOCX 檔案: pip install docx2txt"
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
elif file_ext == '.txt':
|
| 495 |
+
# 嘗試不同的編碼
|
| 496 |
+
encodings = ['utf-8', 'gbk', 'big5', 'latin-1']
|
| 497 |
+
full_text = None
|
| 498 |
+
for encoding in encodings:
|
| 499 |
+
try:
|
| 500 |
+
with open(file_path, 'r', encoding=encoding) as f:
|
| 501 |
+
full_text = f.read()
|
| 502 |
+
break
|
| 503 |
+
except UnicodeDecodeError:
|
| 504 |
+
continue
|
| 505 |
+
|
| 506 |
+
if full_text is None:
|
| 507 |
+
raise ValueError(f"無法讀取檔案,嘗試的編碼都不適用: {encodings}")
|
| 508 |
+
|
| 509 |
+
else:
|
| 510 |
+
raise ValueError(
|
| 511 |
+
f"不支援的檔案類型: {file_ext}\n"
|
| 512 |
+
f"支援的格式: .pdf, .docx, .doc, .txt"
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
if not full_text or len(full_text.strip()) == 0:
|
| 516 |
+
raise ValueError(f"檔案為空或無法提取文字: {file_path}")
|
| 517 |
+
|
| 518 |
+
return {
|
| 519 |
+
"title": file_name,
|
| 520 |
+
"content": full_text,
|
| 521 |
+
"file_path": str(file_path),
|
| 522 |
+
"file_type": file_ext,
|
| 523 |
+
"file_size": file_size,
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
def process_file(self, file_path: str) -> List[Dict]:
|
| 527 |
+
"""
|
| 528 |
+
處理單個檔案,分割成 chunks
|
| 529 |
+
|
| 530 |
+
Args:
|
| 531 |
+
file_path: 檔案路徑
|
| 532 |
+
|
| 533 |
+
Returns:
|
| 534 |
+
處理後的文檔 chunks 列表
|
| 535 |
+
"""
|
| 536 |
+
# 載入檔案
|
| 537 |
+
file_doc = self.load_from_file(file_path)
|
| 538 |
+
|
| 539 |
+
# 分割文字(根據選擇的模式:字符分塊或語義分塊)
|
| 540 |
+
chunks = self.text_splitter.split_text(file_doc["content"])
|
| 541 |
+
|
| 542 |
+
# 後處理:過濾和合併太小的 chunks(僅語義分塊模式)
|
| 543 |
+
chunks = self._post_process_chunks(chunks)
|
| 544 |
+
|
| 545 |
+
if not chunks:
|
| 546 |
+
raise ValueError(f"檔案分割後沒有內容: {file_path}")
|
| 547 |
+
|
| 548 |
+
# 創建文檔 chunks
|
| 549 |
+
documents = []
|
| 550 |
+
for i, chunk in enumerate(chunks):
|
| 551 |
+
doc = {
|
| 552 |
+
"content": chunk,
|
| 553 |
+
"metadata": {
|
| 554 |
+
"title": file_doc["title"],
|
| 555 |
+
"file_path": file_doc["file_path"],
|
| 556 |
+
"file_type": file_doc["file_type"],
|
| 557 |
+
"file_size": file_doc["file_size"],
|
| 558 |
+
"chunk_index": i,
|
| 559 |
+
"total_chunks": len(chunks),
|
| 560 |
+
"chunking_method": "semantic" if self.use_semantic_chunking else "character"
|
| 561 |
+
}
|
| 562 |
+
}
|
| 563 |
+
documents.append(doc)
|
| 564 |
+
|
| 565 |
+
return documents
|
| 566 |
+
|
| 567 |
+
def process_files(self, file_paths: List[str]) -> List[Dict]:
|
| 568 |
+
"""
|
| 569 |
+
處理多個檔案
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
file_paths: 檔案路徑列表
|
| 573 |
+
|
| 574 |
+
Returns:
|
| 575 |
+
所有檔案的文檔 chunks 列表
|
| 576 |
+
"""
|
| 577 |
+
all_documents = []
|
| 578 |
+
for file_path in file_paths:
|
| 579 |
+
try:
|
| 580 |
+
print(f"處理檔案: {file_path}")
|
| 581 |
+
documents = self.process_file(file_path)
|
| 582 |
+
all_documents.extend(documents)
|
| 583 |
+
print(f" ✓ 創建了 {len(documents)} 個 chunks")
|
| 584 |
+
except Exception as e:
|
| 585 |
+
print(f" ✗ 處理檔案失敗: {file_path}")
|
| 586 |
+
print(f" 錯誤: {e}")
|
| 587 |
+
continue
|
| 588 |
+
|
| 589 |
+
return all_documents
|
| 590 |
+
|
src/hybrid_subquery_hyde_rag.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid Sub-query + HyDE RAG:融合 Sub-query Decomposition 和 HyDE
|
| 3 |
+
結合兩種方法的優勢,提升檢索精度
|
| 4 |
+
"""
|
| 5 |
+
from typing import List, Dict, Optional
|
| 6 |
+
from .retrievers.reranker import RAGPipeline
|
| 7 |
+
from .retrievers.vector_retriever import VectorRetriever
|
| 8 |
+
from .prompt_formatter import PromptFormatter
|
| 9 |
+
from .llm_integration import OllamaLLM
|
| 10 |
+
import hashlib
|
| 11 |
+
import time
|
| 12 |
+
import logging
|
| 13 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class HybridSubqueryHyDERAG:
|
| 19 |
+
"""融合 Sub-query Decomposition 和 HyDE 的 RAG 系統"""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
rag_pipeline: RAGPipeline,
|
| 24 |
+
vector_retriever: VectorRetriever,
|
| 25 |
+
llm: OllamaLLM,
|
| 26 |
+
max_sub_queries: int = 3,
|
| 27 |
+
top_k_per_subquery: int = 5,
|
| 28 |
+
hypothetical_length: int = 200,
|
| 29 |
+
temperature_subquery: float = 0.3,
|
| 30 |
+
temperature_hyde: float = 0.7,
|
| 31 |
+
enable_parallel: bool = True
|
| 32 |
+
):
|
| 33 |
+
"""
|
| 34 |
+
初始化融合 RAG
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
rag_pipeline: RAG 管線實例
|
| 38 |
+
vector_retriever: 向量檢索器
|
| 39 |
+
llm: LLM 實例
|
| 40 |
+
max_sub_queries: 最多生成的子問題數量
|
| 41 |
+
top_k_per_subquery: 每個子問題檢索的結果數量
|
| 42 |
+
hypothetical_length: 假設性文檔目標長度(字符數)
|
| 43 |
+
temperature_subquery: 生成子問題的溫度(較低,更穩定)
|
| 44 |
+
temperature_hyde: 生成假設性文檔的溫度(較高,更多專業術語)
|
| 45 |
+
enable_parallel: 是否並行處理
|
| 46 |
+
"""
|
| 47 |
+
self.rag_pipeline = rag_pipeline
|
| 48 |
+
self.vector_retriever = vector_retriever
|
| 49 |
+
self.llm = llm
|
| 50 |
+
self.max_sub_queries = max_sub_queries
|
| 51 |
+
self.top_k_per_subquery = top_k_per_subquery
|
| 52 |
+
self.hypothetical_length = hypothetical_length
|
| 53 |
+
self.temperature_subquery = temperature_subquery
|
| 54 |
+
self.temperature_hyde = temperature_hyde
|
| 55 |
+
self.enable_parallel = enable_parallel
|
| 56 |
+
|
| 57 |
+
def _generate_sub_queries(self, question: str) -> List[str]:
|
| 58 |
+
"""
|
| 59 |
+
生成子問題(與 SubQueryDecompositionRAG 相同)
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
question: 原始問題
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
子問題列表
|
| 66 |
+
"""
|
| 67 |
+
is_chinese = PromptFormatter.detect_language(question) == "zh"
|
| 68 |
+
|
| 69 |
+
if is_chinese:
|
| 70 |
+
prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
|
| 71 |
+
每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。
|
| 72 |
+
|
| 73 |
+
原始問題: {question}
|
| 74 |
+
|
| 75 |
+
子問題清單:"""
|
| 76 |
+
else:
|
| 77 |
+
prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
|
| 78 |
+
Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.
|
| 79 |
+
|
| 80 |
+
Original question: {question}
|
| 81 |
+
|
| 82 |
+
Sub-question list:"""
|
| 83 |
+
|
| 84 |
+
try:
|
| 85 |
+
response = self.llm.generate(
|
| 86 |
+
prompt=prompt,
|
| 87 |
+
temperature=self.temperature_subquery,
|
| 88 |
+
max_tokens=500
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
sub_queries = [
|
| 92 |
+
q.strip()
|
| 93 |
+
for q in response.strip().split("\n")
|
| 94 |
+
if q.strip() and not q.strip().startswith("#")
|
| 95 |
+
]
|
| 96 |
+
|
| 97 |
+
# 移除編號前綴(如 "1. ", "1) " 等)
|
| 98 |
+
cleaned_queries = []
|
| 99 |
+
for q in sub_queries:
|
| 100 |
+
q = q.lstrip("0123456789. )")
|
| 101 |
+
q = q.strip()
|
| 102 |
+
if q:
|
| 103 |
+
cleaned_queries.append(q)
|
| 104 |
+
|
| 105 |
+
cleaned_queries = cleaned_queries[:self.max_sub_queries]
|
| 106 |
+
|
| 107 |
+
if not cleaned_queries:
|
| 108 |
+
logger.warning("⚠️ 未生成子問題,使用原始問題")
|
| 109 |
+
cleaned_queries = [question]
|
| 110 |
+
|
| 111 |
+
return cleaned_queries
|
| 112 |
+
|
| 113 |
+
except Exception as e:
|
| 114 |
+
logger.error(f"⚠️ 生成子問題時出錯: {e}")
|
| 115 |
+
return [question]
|
| 116 |
+
|
| 117 |
+
def _generate_hypothetical_document(self, sub_query: str) -> str:
|
| 118 |
+
"""
|
| 119 |
+
為子問題生成假設性文檔(與 HyDERAG 相同)
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
sub_query: 子問題
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
假設性文檔文本
|
| 126 |
+
"""
|
| 127 |
+
is_chinese = PromptFormatter.detect_language(sub_query) == "zh"
|
| 128 |
+
|
| 129 |
+
if is_chinese:
|
| 130 |
+
prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
|
| 131 |
+
這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
|
| 132 |
+
請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。
|
| 133 |
+
|
| 134 |
+
問題: {sub_query}
|
| 135 |
+
|
| 136 |
+
專業技術內容:"""
|
| 137 |
+
else:
|
| 138 |
+
prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
|
| 139 |
+
This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
|
| 140 |
+
Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.
|
| 141 |
+
|
| 142 |
+
Question: {sub_query}
|
| 143 |
+
|
| 144 |
+
Professional technical content:"""
|
| 145 |
+
|
| 146 |
+
try:
|
| 147 |
+
hypothetical_doc = self.llm.generate(
|
| 148 |
+
prompt=prompt,
|
| 149 |
+
temperature=self.temperature_hyde,
|
| 150 |
+
max_tokens=500
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
hypothetical_doc = hypothetical_doc.strip()
|
| 154 |
+
|
| 155 |
+
if not hypothetical_doc:
|
| 156 |
+
logger.warning(f"⚠️ 子問題 '{sub_query}' 的假設性文檔為空,使用子問題本身")
|
| 157 |
+
return sub_query
|
| 158 |
+
|
| 159 |
+
logger.debug(f"✅ 為子問題生成假設性文檔(長度: {len(hypothetical_doc)} 字符)")
|
| 160 |
+
return hypothetical_doc
|
| 161 |
+
|
| 162 |
+
except Exception as e:
|
| 163 |
+
logger.error(f"⚠️ 生成假設性文檔時出錯: {e}")
|
| 164 |
+
return sub_query
|
| 165 |
+
|
| 166 |
+
def _get_doc_id(self, doc: Dict) -> str:
    """
    Build a stable unique identifier for a retrieved document.

    Prefers ``arxiv_id`` + ``chunk_index``, then ``file_path`` + ``chunk_index``;
    otherwise falls back to a truncated MD5 of the content.

    Args:
        doc: Document dict with optional "metadata" and "content" keys.

    Returns:
        A unique id string.
    """
    meta = doc.get("metadata", {})
    text = doc.get("content", "")

    # Identifier preference order mirrors the ingestion pipelines:
    # arXiv papers first, then local files.
    for key in ("arxiv_id", "file_path"):
        if key in meta and "chunk_index" in meta:
            return f"{meta[key]}_{meta['chunk_index']}"

    # MD5 here is only a fingerprint for deduplication, not a security hash.
    digest = hashlib.md5(text.encode()).hexdigest()[:16]
    return f"doc_{digest}"
|
| 186 |
+
|
| 187 |
+
def _process_subquery_with_hyde(
    self,
    sub_query: str,
    metadata_filter: Optional[Dict] = None
) -> tuple:
    """
    Process one sub-question: generate its hypothetical document and retrieve with it.

    Args:
        sub_query: The sub-question to process.
        metadata_filter: Optional metadata filter forwarded to the retriever.

    Returns:
        Tuple of (retrieved documents, hypothetical document); ``([], "")`` on error.
    """
    try:
        draft = self._generate_hypothetical_document(sub_query)
        # Retrieve with the hypothetical document instead of the raw sub-question.
        hits = self.vector_retriever.retrieve(
            query=draft,
            top_k=self.top_k_per_subquery,
            metadata_filter=metadata_filter
        )
    except Exception as e:
        logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
        return [], ""
    return hits, draft
|
| 218 |
+
|
| 219 |
+
def query(
    self,
    question: str,
    top_k: int = 5,
    metadata_filter: Optional[Dict] = None,
    return_sub_queries: bool = False,
    return_hypothetical: bool = False
) -> Dict:
    """
    Run the fusion RAG retrieval (no answer generation).

    Decomposes the question into sub-questions, generates a hypothetical
    document per sub-question (HyDE), retrieves with each of them, and merges
    the hits — deduplicating by document id and keeping the best score.

    Args:
        question: Original user question.
        top_k: Number of merged results to return.
        metadata_filter: Optional metadata filter passed to the retriever.
        return_sub_queries: Whether to include the sub-question list in the result.
        return_hypothetical: Whether to include the sub-question ->
            hypothetical-document mapping in the result.

    Returns:
        Dict with "results", "total_docs_found", "sub_queries",
        "hypothetical_documents" and "elapsed_time".
    """
    start_time = time.time()

    # Step 1: decompose the question into sub-questions.
    logger.info(f"🔍 拆解問題: '{question}'")
    sub_queries = self._generate_sub_queries(question)
    logger.info(f"✅ 生成 {len(sub_queries)} 個子問題")

    # Step 2: generate a hypothetical document per sub-question and retrieve.
    logger.info(f"📚 為每個子問題生成假設性文檔並檢索...")
    unique_docs = {}
    hypothetical_docs = {}

    def merge_results(results):
        # Deduplicate by document id, keeping the highest-scoring copy.
        # (This logic was previously duplicated verbatim in the parallel
        # and serial branches below.)
        for doc in results:
            doc_id = self._get_doc_id(doc)
            existing = unique_docs.get(doc_id)
            if existing is None or doc.get('score', 0) > existing.get('score', 0):
                unique_docs[doc_id] = doc

    if self.enable_parallel and len(sub_queries) > 1:
        # Parallel processing: fan the sub-questions out to a small thread pool.
        logger.info(f"🔄 並行處理 {len(sub_queries)} 個子問題...")
        with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
            future_to_query = {
                executor.submit(self._process_subquery_with_hyde, sq, metadata_filter): sq
                for sq in sub_queries
            }
            for future in as_completed(future_to_query):
                sub_query = future_to_query[future]
                try:
                    results, hypo_doc = future.result()
                    hypothetical_docs[sub_query] = hypo_doc
                    logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(results)} 個結果")
                    merge_results(results)
                except Exception as e:
                    # One failed sub-question must not abort the whole query.
                    logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
    else:
        # Serial processing.
        logger.info(f"🔄 串行處理 {len(sub_queries)} 個子問題...")
        for sub_query in sub_queries:
            results, hypo_doc = self._process_subquery_with_hyde(sub_query, metadata_filter)
            hypothetical_docs[sub_query] = hypo_doc
            logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(results)} 個結果")
            merge_results(results)

    # Step 3: sort by score and keep the top_k results.
    result_list = sorted(unique_docs.values(), key=lambda x: x.get('score', 0), reverse=True)
    final_results = result_list[:top_k]

    elapsed_time = time.time() - start_time
    logger.info(f"✅ 找到 {len(final_results)} 個唯一文檔(去重後,總共 {len(result_list)} 個)")

    return {
        "results": final_results,
        "total_docs_found": len(result_list),
        "sub_queries": sub_queries if return_sub_queries else None,
        "hypothetical_documents": hypothetical_docs if return_hypothetical else None,
        "elapsed_time": elapsed_time
    }
|
| 315 |
+
|
| 316 |
+
def generate_answer(
    self,
    question: str,
    formatter: PromptFormatter,
    top_k: int = 5,
    metadata_filter: Optional[Dict] = None,
    document_type: str = "general",
    return_sub_queries: bool = False,
    return_hypothetical: bool = False
) -> Dict:
    """
    Full fusion RAG pipeline: retrieval + answer generation.

    Args:
        question: Original user question.
        formatter: PromptFormatter used to build the context and final prompt.
        top_k: Number of documents used to generate the answer.
        metadata_filter: Optional metadata filter for retrieval.
        document_type: Document type ("paper", "cv", "general").
        return_sub_queries: Whether to include the sub-question list.
        return_hypothetical: Whether to include the sub-question ->
            hypothetical-document mapping.

    Returns:
        Dict with the retrieval results, the generated answer, the formatted
        context and timing statistics ("answer_time", "total_time", ...).
    """
    start_time = time.time()

    # Retrieval (sub-question decomposition + HyDE, see query()).
    retrieval_result = self.query(
        question=question,
        top_k=top_k,
        metadata_filter=metadata_filter,
        return_sub_queries=return_sub_queries,
        return_hypothetical=return_hypothetical
    )

    if not retrieval_result["results"]:
        # Nothing retrieved: return a fixed apology answer without calling the LLM.
        return {
            **retrieval_result,
            "answer": "抱歉,未找到相關文檔來回答此問題。",
            "formatted_context": None,
            "answer_time": 0.0,
            "total_time": retrieval_result["elapsed_time"]
        }

    # Format the retrieved chunks into an LLM-readable context.
    formatted_context = formatter.format_context(
        retrieval_result["results"],
        document_type=document_type
    )

    # Build the prompt from the ORIGINAL question (not the hypothetical docs).
    prompt = formatter.create_prompt(
        question,
        formatted_context,
        document_type=document_type
    )

    # Generate the final answer.
    logger.info("🤖 生成回答中...")
    answer_start = time.time()
    try:
        answer = self.llm.generate(
            prompt=prompt,
            temperature=0.7,
            max_tokens=2048
        )
        answer_time = time.time() - answer_start
        logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
    except Exception as e:
        # Surface the error in the answer text instead of raising to the caller.
        logger.error(f"❌ 生成回答時出錯: {e}")
        answer = f"生成回答時出錯: {e}"
        answer_time = time.time() - answer_start

    total_time = time.time() - start_time

    return {
        **retrieval_result,
        "answer": answer,
        "formatted_context": formatted_context,
        "answer_time": answer_time,
        "total_time": total_time
    }
|
| 399 |
+
|
src/hyde_rag.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HyDE (Hypothetical Document Embeddings) RAG:使用假設性文檔改善檢索
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Dict, Optional
|
| 5 |
+
from .retrievers.reranker import RAGPipeline
|
| 6 |
+
from .retrievers.vector_retriever import VectorRetriever
|
| 7 |
+
from .prompt_formatter import PromptFormatter
|
| 8 |
+
from .llm_integration import OllamaLLM
|
| 9 |
+
import time
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class HyDERAG:
    """RAG system using HyDE (Hypothetical Document Embeddings).

    Instead of embedding the user's question directly, an LLM first writes a
    short "hypothetical" technical document answering the question; that text
    is what gets embedded and used for vector retrieval.
    """

    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        vector_retriever: VectorRetriever,
        llm: OllamaLLM,
        hypothetical_length: int = 200,
        temperature: float = 0.7
    ):
        """
        Initialize the HyDE RAG system.

        Args:
            rag_pipeline: RAG pipeline instance (for final answer generation).
                NOTE(review): stored but never used by any method of this
                class — confirm whether it can be dropped.
            vector_retriever: Vector retriever used for hypothetical-document
                based retrieval.
            llm: LLM instance used to generate hypothetical documents and answers.
            hypothetical_length: Target length (characters) of the hypothetical
                document.
            temperature: Sampling temperature for hypothetical-document
                generation (0.7 recommended, to get richer technical vocabulary).
        """
        self.rag_pipeline = rag_pipeline
        self.vector_retriever = vector_retriever
        self.llm = llm
        self.hypothetical_length = hypothetical_length
        self.temperature = temperature

    def _generate_hypothetical_document(self, question: str) -> str:
        """
        Generate a hypothetical document for the question.

        Args:
            question: User question.

        Returns:
            The hypothetical document text; falls back to the original
            question when generation fails or yields an empty string.
        """
        # Detect the question language to pick the prompt template.
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。

問題: {question}

專業技術內容:"""
        else:
            prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.

Question: {question}

Professional technical content:"""

        try:
            hypothetical_doc = self.llm.generate(
                prompt=prompt,
                temperature=self.temperature,  # higher temperature => richer terminology
                max_tokens=500
            )

            # Clean up the output.
            hypothetical_doc = hypothetical_doc.strip()

            if not hypothetical_doc:
                logger.warning("⚠️ 生成的假設性文檔為空,使用原始問題")
                return question

            logger.info(f"✅ 生成假設性文檔(長度: {len(hypothetical_doc)} 字符)")
            return hypothetical_doc

        except Exception as e:
            logger.error(f"⚠️ 生成假設性文檔時出錯: {e}")
            # Fall back to the original question.
            return question

    def query(
        self,
        question: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        return_hypothetical: bool = False
    ) -> Dict:
        """
        Run HyDE retrieval (no answer generation).

        Args:
            question: Original question.
            top_k: Number of results to return.
            metadata_filter: Optional metadata filter.
            return_hypothetical: Whether to include the hypothetical document
                in the result.

        Returns:
            Dict with "results", "total_docs_found", "hypothetical_document"
            and "elapsed_time".
        """
        start_time = time.time()

        # Step 1: generate the hypothetical document.
        logger.info(f"🔍 生成假設性文檔: '{question}'")
        hypothetical_doc = self._generate_hypothetical_document(question)

        # Step 2: retrieve using the hypothetical document.
        logger.info(f"📚 使用假設性文檔進行檢索...")
        results = self.vector_retriever.retrieve(
            query=hypothetical_doc,  # hypothetical document, not the original question
            top_k=top_k,
            metadata_filter=metadata_filter
        )

        elapsed_time = time.time() - start_time
        logger.info(f"✅ 找到 {len(results)} 個結果(耗時: {elapsed_time:.2f}s)")

        result = {
            "results": results,
            "total_docs_found": len(results),
            "hypothetical_document": hypothetical_doc if return_hypothetical else None,
            "elapsed_time": elapsed_time
        }

        return result

    def generate_answer(
        self,
        question: str,
        formatter: PromptFormatter,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        document_type: str = "general",
        return_hypothetical: bool = False
    ) -> Dict:
        """
        Full HyDE RAG pipeline: hypothetical document -> retrieval -> answer.

        NOTE(review): steps 1-2 duplicate the body of query() inline instead
        of calling it (to capture separate "hypothetical_time" and
        "retrieval_time" measurements). Keep both paths in sync.

        Args:
            question: Original question.
            formatter: Prompt formatter.
            top_k: Number of documents used to generate the answer.
            metadata_filter: Optional metadata filter.
            document_type: Document type ("paper", "cv", "general").
            return_hypothetical: Whether to include the hypothetical document.

        Returns:
            Dict with retrieval results, the generated answer, the formatted
            context and timing statistics.
        """
        start_time = time.time()

        # Step 1: generate the hypothetical document.
        logger.info(f"🔍 生成假設性文檔: '{question}'")
        hypothetical_start = time.time()
        hypothetical_doc = self._generate_hypothetical_document(question)
        hypothetical_time = time.time() - hypothetical_start

        # Step 2: retrieve using the hypothetical document.
        logger.info(f"📚 使用假設性文檔進行檢索...")
        retrieval_start = time.time()
        results = self.vector_retriever.retrieve(
            query=hypothetical_doc,  # hypothetical document, not the original question
            top_k=top_k,
            metadata_filter=metadata_filter
        )
        retrieval_time = time.time() - retrieval_start

        if not results:
            # Nothing retrieved: return a fixed apology answer, skip the LLM call.
            return {
                "results": [],
                "total_docs_found": 0,
                "hypothetical_document": hypothetical_doc if return_hypothetical else None,
                "elapsed_time": retrieval_time + hypothetical_time,
                "answer": "抱歉,未找到相關文檔來回答此問題。",
                "formatted_context": None,
                "answer_time": 0.0,
                "total_time": retrieval_time + hypothetical_time
            }

        # Step 3: format the context.
        formatted_context = formatter.format_context(
            results,
            document_type=document_type
        )

        # Step 4: build the prompt from the ORIGINAL question, not the hypothetical doc.
        prompt = formatter.create_prompt(
            question,  # answer generation uses the original question
            formatted_context,
            document_type=document_type
        )

        # Step 5: generate the answer.
        logger.info("🤖 生成回答中...")
        answer_start = time.time()
        try:
            answer = self.llm.generate(
                prompt=prompt,
                temperature=0.7,
                max_tokens=2048
            )
            answer_time = time.time() - answer_start
            logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
        except Exception as e:
            # Surface the error in the answer text instead of raising.
            logger.error(f"❌ 生成回答時出錯: {e}")
            answer = f"生成回答時出錯: {e}"
            answer_time = time.time() - answer_start

        total_time = time.time() - start_time

        return {
            "results": results,
            "total_docs_found": len(results),
            "hypothetical_document": hypothetical_doc if return_hypothetical else None,
            "elapsed_time": retrieval_time + hypothetical_time,
            "hypothetical_time": hypothetical_time,
            "retrieval_time": retrieval_time,
            "answer": answer,
            "formatted_context": formatted_context,
            "answer_time": answer_time,
            "total_time": total_time
        }
|
| 235 |
+
|
src/llm_integration.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM 集成模組:使用 Ollama 進行本地 LLM 推理
|
| 3 |
+
"""
|
| 4 |
+
from typing import Optional, Dict, List
|
| 5 |
+
import logging
|
| 6 |
+
import requests
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class OllamaLLM:
    """Local LLM inference via the Ollama HTTP API (http://localhost:11434 by default)."""

    # Model recommendations suitable for a 16GB MacBook Air.
    RECOMMENDED_MODELS = {
        "deepseek-r1:7b": {
            "name": "deepseek-r1:7b",
            "description": "DeepSeek R1 7B - 大模型,高質量",
            "memory_required": "~8GB",
            "quality": "優秀"
        },
        "llama3.2:3b": {
            "name": "llama3.2:3b",
            "description": "Meta Llama 3.2 3B - 輕量級,適合 16GB 內存",
            "memory_required": "~4GB",
            "quality": "良好"
        },
        "llama3.2:1b": {
            "name": "llama3.2:1b",
            "description": "Meta Llama 3.2 1B - 極輕量級,快速響應",
            "memory_required": "~2GB",
            "quality": "基礎"
        },
        "phi3:mini": {
            "name": "phi3:mini",
            "description": "Microsoft Phi-3 Mini - 小模型,高質量",
            "memory_required": "~3GB",
            "quality": "良好"
        },
        "gemma:2b": {
            "name": "gemma:2b",
            "description": "Google Gemma 2B - 輕量級,開源",
            "memory_required": "~3GB",
            "quality": "良好"
        },
        "mistral:7b": {
            "name": "mistral:7b",
            "description": "Mistral 7B - 較大但質量高(如果內存足夠)",
            "memory_required": "~8GB",
            "quality": "優秀"
        }
    }

    def __init__(
        self,
        model_name: str = "llama3.2:3b",
        base_url: str = "http://localhost:11434",
        timeout: int = 120
    ):
        """
        Initialize the Ollama LLM client.

        Args:
            model_name: Ollama model name (default: llama3.2:3b).
            base_url: Base URL of the Ollama API server.
            timeout: Request timeout in seconds.
        """
        self.model_name = model_name
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout
        self.api_url = f"{self.base_url}/api"

        # Warn (but do not fail) when the model is not in the curated list.
        if model_name not in self.RECOMMENDED_MODELS:
            logger.warning(
                f"⚠️ 模型 '{model_name}' 不在推薦列表中。"
                f"推薦的模型: {', '.join(self.RECOMMENDED_MODELS.keys())}"
            )

        logger.info(f"✅ Ollama LLM 初始化完成 (模型: {model_name})")

    def _check_ollama_connection(self) -> bool:
        """
        Check whether the Ollama service is reachable.

        Returns:
            True when GET /api/tags responds with HTTP 200.
        """
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            return response.status_code == 200
        except Exception as e:
            logger.error(f"❌ 無法連接到 Ollama: {e}")
            logger.error(f"  請確保 Ollama 正在運行: ollama serve")
            return False

    def _check_model_available(self) -> bool:
        """
        Check whether the configured model has been pulled locally.

        Uses substring matching against the names returned by /api/tags, so
        e.g. "llama3.2:3b" matches "llama3.2:3b-instruct" as well.

        Returns:
            True when the model name matches an installed model.
        """
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                models = response.json().get('models', [])
                model_names = [m.get('name', '') for m in models]
                return any(self.model_name in name for name in model_names)
            return False
        except Exception as e:
            logger.error(f"❌ 檢查模型時出錯: {e}")
            return False

    def generate(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        stream: bool = False
    ) -> str:
        """
        Generate a completion for the given prompt.

        NOTE(review): every call issues two extra GET /api/tags round-trips
        (connection check + model check) before the actual generation request —
        consider caching these results if latency matters.

        Args:
            prompt: Input prompt text.
            temperature: Sampling temperature (0.0-1.0), controls randomness.
            max_tokens: Maximum tokens to generate (None = model default).
            stream: Whether to stream the response (chunks are also printed
                to stdout while streaming).

        Returns:
            The generated text.

        Raises:
            ConnectionError: When the Ollama service is unreachable.
            TimeoutError: When the request exceeds ``self.timeout`` seconds.
            RuntimeError: When the API returns a non-200 status.
        """
        # Connection check.
        if not self._check_ollama_connection():
            raise ConnectionError(
                f"無法連接到 Ollama 服務 ({self.base_url})\n"
                f"請確保 Ollama 正在運行:\n"
                f"  1. 安裝 Ollama: https://ollama.ai\n"
                f"  2. 啟動服務: ollama serve\n"
                f"  3. 下載模型: ollama pull {self.model_name}"
            )

        # Model availability check (warn only; the API call below may still work).
        if not self._check_model_available():
            logger.warning(
                f"⚠️ 模型 '{self.model_name}' 可能未下載。"
                f"請運行: ollama pull {self.model_name}"
            )

        # Build the request payload for POST /api/generate.
        payload = {
            "model": self.model_name,
            "prompt": prompt,
            "stream": stream,
            "options": {
                "temperature": temperature,
            }
        }

        # NOTE(review): truthiness test skips an explicit max_tokens=0;
        # `if max_tokens is not None:` would be stricter.
        if max_tokens:
            payload["options"]["num_predict"] = max_tokens

        try:
            # Send the generation request.
            response = requests.post(
                f"{self.api_url}/generate",
                json=payload,
                timeout=self.timeout,
                stream=stream
            )

            if response.status_code != 200:
                error_msg = response.text
                raise RuntimeError(f"Ollama API 錯誤: {error_msg}")

            if stream:
                # Streaming: Ollama emits one JSON object per line; accumulate
                # the "response" chunks and echo them to stdout as they arrive.
                full_response = ""
                for line in response.iter_lines():
                    if line:
                        try:
                            data = json.loads(line)
                            if 'response' in data:
                                chunk = data['response']
                                full_response += chunk
                                print(chunk, end='', flush=True)
                            if data.get('done', False):
                                break
                        except json.JSONDecodeError:
                            # Skip malformed lines rather than aborting the stream.
                            continue
                print()  # trailing newline after the streamed output
                return full_response
            else:
                # Non-streaming: the whole completion is in one JSON body.
                data = response.json()
                return data.get('response', '')

        except requests.exceptions.Timeout:
            raise TimeoutError(
                f"請求超時({self.timeout}秒)。"
                f"可以嘗試增加 timeout 或使用更小的模型。"
            )
        except requests.exceptions.ConnectionError:
            raise ConnectionError(
                f"無法連接到 Ollama 服務。"
                f"請確保 Ollama 正在運行:ollama serve"
            )
        except Exception as e:
            logger.error(f"❌ 生成回答時出錯: {e}")
            raise

    def list_available_models(self) -> List[str]:
        """
        List the models installed on the local Ollama server.

        Returns:
            List of model names; empty list on any error.
        """
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=5)
            if response.status_code == 200:
                models = response.json().get('models', [])
                return [m.get('name', '') for m in models]
            return []
        except Exception as e:
            logger.error(f"❌ 獲取模型列表時出錯: {e}")
            return []

    @classmethod
    def print_recommended_models(cls):
        """Print the curated model recommendations to stdout."""
        print("\n" + "="*60)
        print("適合 16GB MacBook Air 的 Ollama 模型推薦")
        print("="*60)
        print()

        for model_key, info in cls.RECOMMENDED_MODELS.items():
            print(f"📦 {info['name']}")
            print(f"  描述: {info['description']}")
            print(f"  內存需求: {info['memory_required']}")
            print(f"  質量: {info['quality']}")
            print(f"  下載命令: ollama pull {info['name']}")
            print()
|
src/prompt_formatter.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Prompt 格式化模組:將檢索結果格式化為 LLM 可讀的上下文
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Dict, Optional
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class PromptFormatter:
|
| 9 |
+
"""格式化檢索結果供 LLM 使用"""
|
| 10 |
+
|
| 11 |
+
def __init__(
    self,
    include_metadata: bool = True,
    format_style: str = "detailed",
    max_context_length: Optional[int] = None,
    auto_detect_language: bool = True
):
    """
    Initialize the prompt formatter.

    Args:
        include_metadata: Whether to include source information in the context.
        format_style: Formatting style ("detailed", "simple", "minimal").
        max_context_length: Maximum context length in characters (None = unlimited).
        auto_detect_language: Whether to auto-detect the question language and
            instruct the model to answer in the same language.
    """
    self.include_metadata = include_metadata
    self.format_style = format_style
    self.max_context_length = max_context_length
    self.auto_detect_language = auto_detect_language
|
| 31 |
+
|
| 32 |
+
@staticmethod
|
| 33 |
+
def detect_language(text: str) -> str:
|
| 34 |
+
"""
|
| 35 |
+
檢測文本的主要語言
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
text: 輸入文本
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
"zh" 表示中文,"en" 表示英文
|
| 42 |
+
"""
|
| 43 |
+
# 檢查是否包含中文字符(CJK 統一表意文字範圍)
|
| 44 |
+
chinese_pattern = re.compile(r'[\u4e00-\u9fff\u3400-\u4dbf\uf900-\ufaff]')
|
| 45 |
+
chinese_chars = len(chinese_pattern.findall(text))
|
| 46 |
+
|
| 47 |
+
# 計算中文字符比例
|
| 48 |
+
total_chars = len([c for c in text if c.isalnum() or c.isspace()])
|
| 49 |
+
|
| 50 |
+
if total_chars == 0:
|
| 51 |
+
return "en" # 預設英文
|
| 52 |
+
|
| 53 |
+
chinese_ratio = chinese_chars / total_chars if total_chars > 0 else 0
|
| 54 |
+
|
| 55 |
+
# 如果中文字符比例超過 20%,認為是中文
|
| 56 |
+
if chinese_ratio > 0.2:
|
| 57 |
+
return "zh"
|
| 58 |
+
else:
|
| 59 |
+
return "en"
|
| 60 |
+
|
| 61 |
+
def get_system_prompt(self, language: str = "zh", document_type: str = "general") -> str:
|
| 62 |
+
"""
|
| 63 |
+
根據語言和文檔類型獲取系統提示詞
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
language: 語言代碼 ("zh" 或 "en")
|
| 67 |
+
document_type: 文檔類型 ("paper", "cv", "general")
|
| 68 |
+
"paper": 學術論文
|
| 69 |
+
"cv": 履歷/履歷
|
| 70 |
+
"general": 通用文檔(預設)
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
系統提示詞字符串
|
| 74 |
+
"""
|
| 75 |
+
if language == "zh":
|
| 76 |
+
if document_type == "paper":
|
| 77 |
+
return (
|
| 78 |
+
"你是一個專業的 AI 研究助手,專門回答關於機器學習、"
|
| 79 |
+
"深度學習和自然語言處理的問題。\n\n"
|
| 80 |
+
"請基於以下提供的學術論文片段來回答用戶的問題。"
|
| 81 |
+
"每個片段都標註了來源論文的資訊。\n\n"
|
| 82 |
+
"回答要求:\n"
|
| 83 |
+
"1. 基於提供的上下文回答問題\n"
|
| 84 |
+
"2. 如果上下文不足以回答,請明確說明\n"
|
| 85 |
+
"3. 在回答中引用具體的論文來源(使用 arXiv ID)\n"
|
| 86 |
+
"4. 如果不同論文有不同觀點,請分別說明\n"
|
| 87 |
+
"5. 保持回答簡潔、準確、專業\n"
|
| 88 |
+
"6. **重要:請使用與用戶問題相同的語言回答**\n"
|
| 89 |
+
)
|
| 90 |
+
elif document_type == "cv":
|
| 91 |
+
return (
|
| 92 |
+
"你是一個專業的 AI 助手,專門幫助分析和介紹簡歷(CV)內容。\n\n"
|
| 93 |
+
"請基於以下提供的文檔片段來回答用戶的問題。"
|
| 94 |
+
"這些片段來自一份簡歷或履歷表。\n\n"
|
| 95 |
+
"回答要求:\n"
|
| 96 |
+
"1. 基於提供的上下文回答問題\n"
|
| 97 |
+
"2. 如果上下文不足以回答,請明確說明\n"
|
| 98 |
+
"3. 在回答中引用具體的文檔內容\n"
|
| 99 |
+
"4. 保持回答簡潔、準確、專業\n"
|
| 100 |
+
"5. **重要:請使用與用戶問題相同的語言回答**\n"
|
| 101 |
+
"6. **請理解:這些片段就是簡歷的內容,請直接基於這些內容回答問題**\n"
|
| 102 |
+
)
|
| 103 |
+
else: # general
|
| 104 |
+
return (
|
| 105 |
+
"你是一個專業的 AI 助手。\n\n"
|
| 106 |
+
"請基於以下提供的文檔片段來回答用戶的問題。"
|
| 107 |
+
"每個片段都標註了來源資訊。\n\n"
|
| 108 |
+
"回答要求:\n"
|
| 109 |
+
"1. 基於提供的上下文回答問題\n"
|
| 110 |
+
"2. 如果上下文不足以回答,請明確說明\n"
|
| 111 |
+
"3. 在回答中引用具體的文檔內容\n"
|
| 112 |
+
"4. 保持回答簡潔、準確、專業\n"
|
| 113 |
+
"5. **重要:請使用與用戶問題相同的語言回答**\n"
|
| 114 |
+
)
|
| 115 |
+
else: # English
|
| 116 |
+
if document_type == "paper":
|
| 117 |
+
return (
|
| 118 |
+
"You are a professional AI research assistant specializing in "
|
| 119 |
+
"machine learning, deep learning, and natural language processing.\n\n"
|
| 120 |
+
"Please answer the user's question based on the provided academic paper excerpts. "
|
| 121 |
+
"Each excerpt is labeled with source paper information.\n\n"
|
| 122 |
+
"Answer requirements:\n"
|
| 123 |
+
"1. Answer the question based on the provided context\n"
|
| 124 |
+
"2. If the context is insufficient, clearly state so\n"
|
| 125 |
+
"3. Cite specific paper sources in your answer (using arXiv ID)\n"
|
| 126 |
+
"4. If different papers have different viewpoints, explain them separately\n"
|
| 127 |
+
"5. Keep answers concise, accurate, and professional\n"
|
| 128 |
+
"6. **Important: Please answer in the same language as the user's question**\n"
|
| 129 |
+
)
|
| 130 |
+
elif document_type == "cv":
|
| 131 |
+
return (
|
| 132 |
+
"You are a professional AI assistant specializing in analyzing and introducing CV (Curriculum Vitae) content.\n\n"
|
| 133 |
+
"Please answer the user's question based on the provided document excerpts. "
|
| 134 |
+
"These excerpts are from a CV or resume.\n\n"
|
| 135 |
+
"Answer requirements:\n"
|
| 136 |
+
"1. Answer the question based on the provided context\n"
|
| 137 |
+
"2. If the context is insufficient, clearly state so\n"
|
| 138 |
+
"3. Cite specific document content in your answer\n"
|
| 139 |
+
"4. Keep answers concise, accurate, and professional\n"
|
| 140 |
+
"5. **Important: Please answer in the same language as the user's question**\n"
|
| 141 |
+
"6. **Please understand: These excerpts ARE the CV content. Please answer directly based on this content.**\n"
|
| 142 |
+
)
|
| 143 |
+
else: # general
|
| 144 |
+
return (
|
| 145 |
+
"You are a professional AI assistant.\n\n"
|
| 146 |
+
"Please answer the user's question based on the provided document excerpts. "
|
| 147 |
+
"Each excerpt is labeled with source information.\n\n"
|
| 148 |
+
"Answer requirements:\n"
|
| 149 |
+
"1. Answer the question based on the provided context\n"
|
| 150 |
+
"2. If the context is insufficient, clearly state so\n"
|
| 151 |
+
"3. Cite specific document content in your answer\n"
|
| 152 |
+
"4. Keep answers concise, accurate, and professional\n"
|
| 153 |
+
"5. **Important: Please answer in the same language as the user's question**\n"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def format_context(
|
| 157 |
+
self,
|
| 158 |
+
results: List[Dict],
|
| 159 |
+
include_metadata: Optional[bool] = None,
|
| 160 |
+
format_style: Optional[str] = None,
|
| 161 |
+
document_type: str = "general"
|
| 162 |
+
) -> str:
|
| 163 |
+
"""
|
| 164 |
+
格式化檢索結果為 LLM 可讀的上下文
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
results: 檢索結果列表
|
| 168 |
+
include_metadata: 是否包含來源資訊(覆蓋初始化參數)
|
| 169 |
+
format_style: 格式風格(覆蓋初始化參數)
|
| 170 |
+
document_type: 文檔類型 ("paper", "cv", "general"),用於調整格式
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
格式化後的上下文字符串
|
| 174 |
+
"""
|
| 175 |
+
if include_metadata is None:
|
| 176 |
+
include_metadata = self.include_metadata
|
| 177 |
+
if format_style is None:
|
| 178 |
+
format_style = self.format_style
|
| 179 |
+
|
| 180 |
+
if not results:
|
| 181 |
+
# 根據格式風格選擇語言
|
| 182 |
+
if format_style == "detailed" or format_style == "simple":
|
| 183 |
+
return "(未找到相關文檔片段)"
|
| 184 |
+
else:
|
| 185 |
+
return "(No relevant excerpts found)"
|
| 186 |
+
|
| 187 |
+
formatted_parts = []
|
| 188 |
+
|
| 189 |
+
for i, result in enumerate(results, 1):
|
| 190 |
+
content = result.get("content", "")
|
| 191 |
+
metadata = result.get("metadata", {})
|
| 192 |
+
|
| 193 |
+
if not include_metadata:
|
| 194 |
+
# 不包含來源資訊,直接使用內容
|
| 195 |
+
formatted_parts.append(f"{content}\n")
|
| 196 |
+
elif format_style == "detailed":
|
| 197 |
+
# 詳細格式:根據文檔類型調整顯示資訊
|
| 198 |
+
if document_type == "cv":
|
| 199 |
+
# CV 格式:顯示檔案名和路徑
|
| 200 |
+
source_info = (
|
| 201 |
+
f"[來源 {i}]\n"
|
| 202 |
+
f"檔案標題: {metadata.get('title', 'N/A')}\n"
|
| 203 |
+
)
|
| 204 |
+
if 'file_path' in metadata:
|
| 205 |
+
source_info += f"檔案路徑: {metadata.get('file_path', 'N/A')}\n"
|
| 206 |
+
if 'file_type' in metadata:
|
| 207 |
+
source_info += f"檔案類型: {metadata.get('file_type', 'N/A')}\n"
|
| 208 |
+
elif document_type == "paper":
|
| 209 |
+
# 論文格式:顯示論文資訊
|
| 210 |
+
authors = metadata.get('authors', [])
|
| 211 |
+
if isinstance(authors, str):
|
| 212 |
+
authors_str = authors
|
| 213 |
+
elif isinstance(authors, list):
|
| 214 |
+
authors_str = ', '.join(authors[:3]) # 最多顯示 3 個作者
|
| 215 |
+
if len(authors) > 3:
|
| 216 |
+
authors_str += f" 等 {len(authors)} 位作者"
|
| 217 |
+
else:
|
| 218 |
+
authors_str = 'N/A'
|
| 219 |
+
|
| 220 |
+
source_info = (
|
| 221 |
+
f"[來源 {i}]\n"
|
| 222 |
+
f"論文標題: {metadata.get('title', 'N/A')}\n"
|
| 223 |
+
f"arXiv ID: {metadata.get('arxiv_id', 'N/A')}\n"
|
| 224 |
+
f"作者: {authors_str}\n"
|
| 225 |
+
f"發布日期: {metadata.get('published', 'N/A')}\n"
|
| 226 |
+
)
|
| 227 |
+
else:
|
| 228 |
+
# 通用格式:顯示可用的資訊
|
| 229 |
+
source_info = f"[來源 {i}]\n"
|
| 230 |
+
if 'title' in metadata:
|
| 231 |
+
source_info += f"標題: {metadata.get('title', 'N/A')}\n"
|
| 232 |
+
if 'file_path' in metadata:
|
| 233 |
+
source_info += f"檔案: {metadata.get('file_path', 'N/A')}\n"
|
| 234 |
+
if 'arxiv_id' in metadata:
|
| 235 |
+
source_info += f"arXiv ID: {metadata.get('arxiv_id', 'N/A')}\n"
|
| 236 |
+
|
| 237 |
+
# 添加相關性分數(如果有的話)
|
| 238 |
+
rerank_score = result.get('rerank_score')
|
| 239 |
+
hybrid_score = result.get('hybrid_score')
|
| 240 |
+
if rerank_score is not None:
|
| 241 |
+
source_info += f"相關性分數: {rerank_score:.4f}\n"
|
| 242 |
+
elif hybrid_score is not None:
|
| 243 |
+
source_info += f"相關性分數: {hybrid_score:.4f}\n"
|
| 244 |
+
|
| 245 |
+
source_info += f"---\n{content}\n"
|
| 246 |
+
formatted_parts.append(source_info)
|
| 247 |
+
|
| 248 |
+
elif format_style == "simple":
|
| 249 |
+
# 簡單格式:只包含關鍵資訊
|
| 250 |
+
title = metadata.get('title', 'N/A')
|
| 251 |
+
if document_type == "paper" and 'arxiv_id' in metadata:
|
| 252 |
+
arxiv_id = metadata.get('arxiv_id', 'N/A')
|
| 253 |
+
source_info = (
|
| 254 |
+
f"[來源 {i}: {title} "
|
| 255 |
+
f"(arXiv:{arxiv_id})]\n"
|
| 256 |
+
f"{content}\n"
|
| 257 |
+
)
|
| 258 |
+
elif document_type == "cv" and 'file_path' in metadata:
|
| 259 |
+
file_path = metadata.get('file_path', 'N/A')
|
| 260 |
+
source_info = (
|
| 261 |
+
f"[來源 {i}: {title} "
|
| 262 |
+
f"({file_path})]\n"
|
| 263 |
+
f"{content}\n"
|
| 264 |
+
)
|
| 265 |
+
else:
|
| 266 |
+
source_info = (
|
| 267 |
+
f"[來源 {i}: {title}]\n"
|
| 268 |
+
f"{content}\n"
|
| 269 |
+
)
|
| 270 |
+
formatted_parts.append(source_info)
|
| 271 |
+
else: # minimal
|
| 272 |
+
# 最小格式:只標註來源
|
| 273 |
+
if document_type == "paper" and 'arxiv_id' in metadata:
|
| 274 |
+
arxiv_id = metadata.get('arxiv_id', 'N/A')
|
| 275 |
+
source_info = (
|
| 276 |
+
f"[arXiv:{arxiv_id}]\n"
|
| 277 |
+
f"{content}\n"
|
| 278 |
+
)
|
| 279 |
+
elif 'title' in metadata:
|
| 280 |
+
title = metadata.get('title', 'N/A')
|
| 281 |
+
source_info = (
|
| 282 |
+
f"[{title}]\n"
|
| 283 |
+
f"{content}\n"
|
| 284 |
+
)
|
| 285 |
+
else:
|
| 286 |
+
source_info = (
|
| 287 |
+
f"[來源 {i}]\n"
|
| 288 |
+
f"{content}\n"
|
| 289 |
+
)
|
| 290 |
+
formatted_parts.append(source_info)
|
| 291 |
+
|
| 292 |
+
formatted_text = "\n" + "="*60 + "\n".join(formatted_parts)
|
| 293 |
+
|
| 294 |
+
# 如果設置了最大長度,進行截斷
|
| 295 |
+
if self.max_context_length and len(formatted_text) > self.max_context_length:
|
| 296 |
+
# 從後往前截斷,保留格式
|
| 297 |
+
formatted_text = formatted_text[:self.max_context_length]
|
| 298 |
+
# 確保最後一個來源資訊完整
|
| 299 |
+
last_separator = formatted_text.rfind("="*60)
|
| 300 |
+
if last_separator > 0:
|
| 301 |
+
formatted_text = formatted_text[:last_separator] + "\n(內容已截斷...)"
|
| 302 |
+
|
| 303 |
+
return formatted_text
|
| 304 |
+
|
| 305 |
+
def create_prompt(
|
| 306 |
+
self,
|
| 307 |
+
query: str,
|
| 308 |
+
context: str,
|
| 309 |
+
system_prompt: Optional[str] = None,
|
| 310 |
+
document_type: str = "general"
|
| 311 |
+
) -> str:
|
| 312 |
+
"""
|
| 313 |
+
創建完整的 LLM prompt
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
query: 用戶查詢
|
| 317 |
+
context: 格式化後的上下文
|
| 318 |
+
system_prompt: 可選的系統提示詞(如果為 None,會根據語言和文檔類型自動選擇)
|
| 319 |
+
document_type: 文檔類型 ("paper", "cv", "general")
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
完整的 prompt 字符串
|
| 323 |
+
"""
|
| 324 |
+
# 自動檢測語言並選擇相應的系統提示詞
|
| 325 |
+
if system_prompt is None and self.auto_detect_language:
|
| 326 |
+
detected_language = self.detect_language(query)
|
| 327 |
+
system_prompt = self.get_system_prompt(detected_language, document_type)
|
| 328 |
+
elif system_prompt is None:
|
| 329 |
+
# 如果禁用自動檢測,使用中文作為預設
|
| 330 |
+
system_prompt = self.get_system_prompt("zh", document_type)
|
| 331 |
+
|
| 332 |
+
# 根據檢測到的語言選擇提示詞格式
|
| 333 |
+
detected_language = self.detect_language(query) if self.auto_detect_language else "zh"
|
| 334 |
+
|
| 335 |
+
# 根據文檔類型選擇不同的提示詞結尾
|
| 336 |
+
if document_type == "paper":
|
| 337 |
+
if detected_language == "zh":
|
| 338 |
+
ending = "## 請基於上述文獻片段回答問題,並在回答中引用具體的論文來源。"
|
| 339 |
+
else:
|
| 340 |
+
ending = "## Please answer the question based on the above document excerpts and cite specific paper sources in your answer."
|
| 341 |
+
else:
|
| 342 |
+
if detected_language == "zh":
|
| 343 |
+
ending = "## 請基於上述文檔片段回答問題,並在回答中引用具體的文檔內容。"
|
| 344 |
+
else:
|
| 345 |
+
ending = "## Please answer the question based on the above document excerpts and cite specific document content in your answer."
|
| 346 |
+
|
| 347 |
+
if detected_language == "zh":
|
| 348 |
+
prompt = f"""{system_prompt}
|
| 349 |
+
|
| 350 |
+
## 相關文檔片段:
|
| 351 |
+
|
| 352 |
+
{context}
|
| 353 |
+
|
| 354 |
+
## 用戶問題:
|
| 355 |
+
|
| 356 |
+
{query}
|
| 357 |
+
|
| 358 |
+
{ending}"""
|
| 359 |
+
else: # English
|
| 360 |
+
prompt = f"""{system_prompt}
|
| 361 |
+
|
| 362 |
+
## Relevant Document Excerpts:
|
| 363 |
+
|
| 364 |
+
{context}
|
| 365 |
+
|
| 366 |
+
## User Question:
|
| 367 |
+
|
| 368 |
+
{query}
|
| 369 |
+
|
| 370 |
+
{ending}"""
|
| 371 |
+
|
| 372 |
+
return prompt
|
| 373 |
+
|
| 374 |
+
def format_for_llm(
|
| 375 |
+
self,
|
| 376 |
+
query: str,
|
| 377 |
+
results: List[Dict],
|
| 378 |
+
system_prompt: Optional[str] = None,
|
| 379 |
+
document_type: str = "general"
|
| 380 |
+
) -> str:
|
| 381 |
+
"""
|
| 382 |
+
一站式方法:格式化檢索結果並創建完整的 prompt
|
| 383 |
+
|
| 384 |
+
Args:
|
| 385 |
+
query: 用戶查詢
|
| 386 |
+
results: 檢索結果列表
|
| 387 |
+
system_prompt: 可選的系統提示詞
|
| 388 |
+
document_type: 文檔類型 ("paper", "cv", "general")
|
| 389 |
+
|
| 390 |
+
Returns:
|
| 391 |
+
完整的 prompt 字符串
|
| 392 |
+
"""
|
| 393 |
+
context = self.format_context(results, document_type=document_type)
|
| 394 |
+
return self.create_prompt(query, context, system_prompt, document_type)
|
| 395 |
+
|
src/retrievers/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
檢索器模組
|
| 3 |
+
"""
|
| 4 |
+
from .base import BaseRetriever
|
| 5 |
+
from .bm25_retriever import BM25Retriever
|
| 6 |
+
from .vector_retriever import VectorRetriever
|
| 7 |
+
from .hybrid_search import HybridSearch
|
| 8 |
+
from .reranker import Reranker, RAGPipeline
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"BaseRetriever",
|
| 12 |
+
"BM25Retriever",
|
| 13 |
+
"VectorRetriever",
|
| 14 |
+
"HybridSearch",
|
| 15 |
+
"Reranker",
|
| 16 |
+
"RAGPipeline",
|
| 17 |
+
]
|
src/retrievers/base.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
檢索器模組的抽象基類
|
| 3 |
+
"""
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
from typing import List, Dict, Optional
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class BaseRetriever(ABC):
    """Abstract base class for all retrievers."""

    @abstractmethod
    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Retrieve relevant documents and return scored results.

        Args:
            query: Query text.
            top_k: Return the top-k results.
            metadata_filter: Optional dict of metadata filter conditions,
                e.g. {"arxiv_id": "1234.5678"} or {"title": "Machine Learning"}.
                Multiple conditions are combined with AND logic (all must hold).

        Returns:
            List of relevant document dicts; each dict should contain a
            "score" key where a higher score means more relevant. Results
            are filtered according to metadata_filter.
        """
        pass
|
src/retrievers/bm25_retriever.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BM25 檢索器模組
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Dict, Optional
|
| 5 |
+
from rank_bm25 import BM25Okapi
|
| 6 |
+
import re
|
| 7 |
+
from .base import BaseRetriever
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BM25Retriever(BaseRetriever):
    """Keyword retriever based on the BM25 ranking algorithm."""

    def __init__(self, documents: List[Dict]):
        """Initialize the BM25 retriever.

        Args:
            documents: Document list; each dict has "content" and "metadata".
        """
        self.documents = documents
        self.texts = [doc["content"] for doc in documents]

        # Tokenize every document once up front.
        tokenized_texts = [self._tokenize(text) for text in self.texts]

        # Build the BM25 index.
        self.bm25 = BM25Okapi(tokenized_texts)

    def _tokenize(self, text: str) -> List[str]:
        """Lowercase *text* and split it into word tokens.

        Args:
            text: Input text.

        Returns:
            List of tokens (alphanumeric word characters only).
        """
        return re.findall(r'\b\w+\b', text.lower())

    @staticmethod
    def _matches_filter(metadata: Dict, metadata_filter: Dict) -> bool:
        """Return True when *metadata* satisfies every filter condition (AND).

        String conditions use case-insensitive substring matching; all other
        types use exact equality.
        """
        for filter_key, filter_value in metadata_filter.items():
            doc_value = metadata.get(filter_key)
            if isinstance(filter_value, str) and isinstance(doc_value, str):
                if filter_value.lower() not in doc_value.lower():
                    return False
            elif doc_value != filter_value:
                return False
        return True

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """Retrieve relevant documents, optionally filtered by metadata.

        Args:
            query: Query text.
            top_k: Return the top-k results.
            metadata_filter: Optional metadata filter dict, e.g.
                {"arxiv_id": "1234.5678"} to restrict to one paper's chunks.
                All conditions must hold simultaneously (AND logic).

        Returns:
            List of dicts with "content", "metadata", "score", sorted by
            descending BM25 score. When a filter is given, the FULL ranking
            is scanned (the previous implementation only examined the top
            3*top_k candidates and could miss matching documents ranked
            below that window), so up to top_k matching results are returned.
        """
        tokenized_query = self._tokenize(query)
        scores = self.bm25.get_scores(tokenized_query)

        # Rank all document indices by descending BM25 score.
        ranked_indices = sorted(
            range(len(scores)),
            key=lambda i: scores[i],
            reverse=True
        )

        results = []
        for idx in ranked_indices:
            doc = self.documents[idx]
            # Skip documents that fail the metadata filter (if any).
            if metadata_filter and not self._matches_filter(
                doc.get("metadata", {}), metadata_filter
            ):
                continue
            results.append({
                "content": doc["content"],
                "metadata": doc["metadata"],
                "score": float(scores[idx]),
            })
            if len(results) >= top_k:
                break

        return results
|
| 127 |
+
|
src/retrievers/hybrid_search.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid Search 模組:結合 BM25 和向量檢索
|
| 3 |
+
支援兩種融合方法:加權求和(Weighted Sum)和倒數排名融合(RRF)
|
| 4 |
+
"""
|
| 5 |
+
from typing import List, Dict, Optional, Literal
|
| 6 |
+
from .base import BaseRetriever
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HybridSearch(BaseRetriever):
    """Hybrid search combining sparse (e.g. BM25) and dense (vector) retrieval.

    Two fusion strategies are supported:
      - "weighted_sum": min-max normalize both score sets, then combine them
        with configurable weights.
      - "rrf": Reciprocal Rank Fusion — rank-based, needs no score
        normalization, and is robust to differing score distributions.
    """

    def __init__(
        self,
        sparse_retriever: BaseRetriever,
        dense_retriever: BaseRetriever,
        sparse_weight: float = 0.4,
        dense_weight: float = 0.6,
        fusion_method: Literal["weighted_sum", "rrf"] = "rrf",
        rrf_k: int = 60,
    ):
        """Initialize the hybrid search.

        Args:
            sparse_retriever: Sparse retriever (e.g. BM25).
            dense_retriever: Dense retriever (e.g. vector search).
            sparse_weight: Weight of the sparse score (weighted_sum only).
            dense_weight: Weight of the dense score (weighted_sum only).
            fusion_method: "weighted_sum" or "rrf".
            rrf_k: RRF constant k, conventionally 60 (rrf only). Larger k
                gives lower-ranked documents relatively more weight.
        """
        self.sparse_retriever = sparse_retriever
        self.dense_retriever = dense_retriever
        self.fusion_method = fusion_method
        self.rrf_k = rrf_k

        # Weights only apply to the weighted_sum method; normalize them so
        # they sum to 1.
        if fusion_method == "weighted_sum":
            total_weight = sparse_weight + dense_weight
            if abs(total_weight - 1.0) > 1e-6:
                sparse_weight = sparse_weight / total_weight
                dense_weight = dense_weight / total_weight
            self.sparse_weight = sparse_weight
            self.dense_weight = dense_weight

    def _normalize_scores(self, results: List[Dict]) -> List[Dict]:
        """Return copies of *results* with scores min-max normalized to [0, 1].

        Only used by the weighted_sum method. Unlike the previous in-place
        version, the caller's dicts are NOT mutated (retriever results may be
        shared or cached elsewhere).

        Args:
            results: Result dicts, each with a "score" key.

        Returns:
            New list of copied dicts with normalized "score" values.
        """
        scores = [res.get("score", 0.0) for res in results]
        if not scores:
            return list(results)

        min_score = min(scores)
        span = max(scores) - min_score

        normalized = []
        for res, raw in zip(results, scores):
            copy = res.copy()
            # All-equal scores collapse to 1.0 (no ordering information).
            copy["score"] = 1.0 if span == 0 else (raw - min_score) / span
            normalized.append(copy)
        return normalized

    def _get_doc_id(self, doc: Dict) -> str:
        """Extract a unique identifier for *doc*.

        Args:
            doc: Document dict.

        Returns:
            A string ID. arXiv chunks use "<arxiv_id>_<chunk_index>";
            documents without an arxiv_id fall back to a content-derived ID.
            (Previously every non-arXiv document collapsed into the single
            ID "unknown_0", merging unrelated chunks during fusion.)
        """
        metadata = doc.get("metadata", {})
        arxiv_id = metadata.get("arxiv_id")
        if arxiv_id is not None:
            return f"{arxiv_id}_{metadata.get('chunk_index', 0)}"
        return f"content_{hash(doc.get('content', ''))}"

    def _apply_rrf(
        self,
        sparse_results: List[Dict],
        dense_results: List[Dict]
    ) -> List[Dict]:
        """Fuse result lists with Reciprocal Rank Fusion.

        RRF formula: RRF(d) = Σ 1 / (k + rank_i(d)), where rank_i(d) is the
        1-based rank of document d in retrieval list i and k = self.rrf_k.

        Advantages: no score normalization needed, only rank positions are
        used, so differing score distributions are handled automatically.

        Args:
            sparse_results: Sparse retrieval results (rank order preserved).
            dense_results: Dense retrieval results (rank order preserved).

        Returns:
            Fused results sorted by descending RRF score. Each dict gains
            "hybrid_score", "rrf_score", "sparse_rank", "dense_rank",
            "sparse_score", "dense_score" (scores are None for lists the
            document did not appear in).
        """
        fused = {}
        # Original scores, keyed by doc id, so we avoid the previous O(n^2)
        # nested-loop lookup when attaching them to the fused results.
        sparse_scores = {}
        dense_scores = {}

        for rank, result in enumerate(sparse_results, start=1):
            doc_id = self._get_doc_id(result)
            entry = fused.setdefault(doc_id, {
                "doc": result,
                "rrf_score": 0.0,
                "sparse_rank": None,
                "dense_rank": None,
            })
            entry["rrf_score"] += 1.0 / (self.rrf_k + rank)
            entry["sparse_rank"] = rank
            sparse_scores[doc_id] = result.get("score", 0.0)

        for rank, result in enumerate(dense_results, start=1):
            doc_id = self._get_doc_id(result)
            entry = fused.setdefault(doc_id, {
                "doc": result,
                "rrf_score": 0.0,
                "sparse_rank": None,
                "dense_rank": None,
            })
            entry["rrf_score"] += 1.0 / (self.rrf_k + rank)
            entry["dense_rank"] = rank
            dense_scores[doc_id] = result.get("score", 0.0)

        rrf_results = []
        for doc_id, entry in fused.items():
            result = entry["doc"].copy()
            result["hybrid_score"] = entry["rrf_score"]
            result["rrf_score"] = entry["rrf_score"]
            result["sparse_rank"] = entry["sparse_rank"]
            result["dense_rank"] = entry["dense_rank"]
            result["sparse_score"] = sparse_scores.get(doc_id)
            result["dense_score"] = dense_scores.get(doc_id)
            rrf_results.append(result)

        rrf_results.sort(key=lambda x: x["rrf_score"], reverse=True)
        return rrf_results

    def _apply_weighted_sum(
        self,
        sparse_results: List[Dict],
        dense_results: List[Dict]
    ) -> List[Dict]:
        """Fuse result lists by weighted sum of normalized scores.

        Steps: (1) min-max normalize each score set to a common range, then
        (2) combine per-document scores with the configured weights (a
        missing score counts as 0.0).

        Args:
            sparse_results: Sparse retrieval results.
            dense_results: Dense retrieval results.

        Returns:
            Fused results sorted by descending "hybrid_score"; each dict also
            gains "sparse_score" and "dense_score" (normalized values).
        """
        normalized_sparse = self._normalize_scores(sparse_results)
        normalized_dense = self._normalize_scores(dense_results)

        doc_to_scores = {}
        for res in normalized_sparse:
            doc_id = self._get_doc_id(res)
            entry = doc_to_scores.setdefault(doc_id, {"doc": res, "sparse": 0.0, "dense": 0.0})
            entry["sparse"] = res["score"]
        for res in normalized_dense:
            doc_id = self._get_doc_id(res)
            entry = doc_to_scores.setdefault(doc_id, {"doc": res, "sparse": 0.0, "dense": 0.0})
            entry["dense"] = res["score"]

        hybrid_results = []
        for scores in doc_to_scores.values():
            result = scores["doc"].copy()
            result["hybrid_score"] = (
                self.sparse_weight * scores["sparse"] +
                self.dense_weight * scores["dense"]
            )
            result["sparse_score"] = scores["sparse"]
            result["dense_score"] = scores["dense"]
            hybrid_results.append(result)

        hybrid_results.sort(key=lambda x: x["hybrid_score"], reverse=True)
        return hybrid_results

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """Run the hybrid search, optionally filtered by metadata.

        Args:
            query: Query text.
            top_k: Return the top-k fused results.
            metadata_filter: Optional metadata filter dict (AND logic across
                conditions); forwarded to both underlying retrievers.

        Returns:
            Fused document list with "content", "metadata", "hybrid_score".
            Extra score fields depend on the fusion method:
              - rrf: "rrf_score", "sparse_rank", "dense_rank"
              - weighted_sum: "sparse_score", "dense_score"
        """
        # 1. Over-fetch from both retrievers (2x) to improve fusion coverage.
        sparse_results = self.sparse_retriever.retrieve(
            query,
            top_k=top_k * 2,
            metadata_filter=metadata_filter
        )
        dense_results = self.dense_retriever.retrieve(
            query,
            top_k=top_k * 2,
            metadata_filter=metadata_filter
        )

        # 2. Fuse with the configured strategy.
        if self.fusion_method == "rrf":
            hybrid_results = self._apply_rrf(sparse_results, dense_results)
        else:
            hybrid_results = self._apply_weighted_sum(sparse_results, dense_results)

        # 3. Return the top-k fused results.
        return hybrid_results[:top_k]
|
| 298 |
+
|
src/retrievers/reranker.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
重排序模組:使用 Cross-Encoder 進行精準重排
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Dict, Optional, Tuple
|
| 5 |
+
from sentence_transformers import CrossEncoder
|
| 6 |
+
import time
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
# 嘗試導入 torch 來檢測可用的設備
|
| 10 |
+
try:
|
| 11 |
+
import torch
|
| 12 |
+
TORCH_AVAILABLE = True
|
| 13 |
+
except ImportError:
|
| 14 |
+
TORCH_AVAILABLE = False
|
| 15 |
+
|
| 16 |
+
# Logging setup.
# NOTE(review): calling basicConfig() at import time configures the root
# logger as a module side effect; library modules conventionally leave this
# to the application entry point — consider moving it there.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_device() -> str:
    """
    Detect and return the best available compute device.

    Returns:
        'mps' (macOS GPU), 'cuda' (NVIDIA GPU), or 'cpu'.
    """
    if not TORCH_AVAILABLE:
        return 'cpu'

    # Preference order: MPS (macOS) > CUDA (NVIDIA) > CPU.
    if torch.backends.mps.is_available():
        return 'mps'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class Reranker:
    """Cross-Encoder based reranking component for precise second-stage ordering.

    Scores (query, document) pairs jointly with a Cross-Encoder and sorts
    candidates by that relevance score.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-reranker-base",
        device: str = None,
        max_length: int = 512,
        batch_size: int = 32,
        enable_cache: bool = True
    ):
        """
        Initialize the Cross-Encoder model.

        Args:
            model_name: Cross-Encoder model name.
            device: Device name ('cuda', 'cpu', 'mps'); auto-detected if None.
            max_length: Maximum token length supported by the model.
            batch_size: Batch size used to bound memory during scoring.
            enable_cache: Reserved flag for model caching (currently unused;
                kept for interface compatibility).

        Raises:
            Exception: re-raised when the model fails to load.
        """
        try:
            # Auto-detect the device when not explicitly given.
            if device is None:
                device = get_device()

            device_name_map = {
                'mps': 'MPS (macOS GPU)',
                'cuda': 'CUDA (NVIDIA GPU)',
                'cpu': 'CPU'
            }
            device_display = device_name_map.get(device, device)

            self.model = CrossEncoder(
                model_name,
                device=device,
                max_length=max_length
            )
            self.max_length = max_length
            self.batch_size = batch_size
            self.model_name = model_name
            logger.info(f"✅ 重排模型 {model_name} 已載入 (device: {device_display})")
        except Exception as e:
            logger.error(f"❌ 模型載入失敗: {e}")
            raise

    def _truncate_text(self, text: str, max_chars: int = 2000) -> str:
        """
        Truncate over-long text (rough character bound to stay under the
        model's token limit).

        Args:
            text: Original text.
            max_chars: Maximum number of characters (conservative estimate,
                roughly 500 tokens by default).

        Returns:
            The possibly-truncated text, with a trailing ellipsis when cut.
        """
        if len(text) <= max_chars:
            return text
        # Truncate and append an ellipsis.
        return text[:max_chars - 3] + "..."

    def _prepare_pairs(
        self,
        query: str,
        documents: List[Dict]
    ) -> List[Tuple[str, str]]:
        """
        Build (query, content) pairs, bounding document length.

        Args:
            query: Query text.
            documents: Candidate documents, each with a "content" key.

        Returns:
            List of [query, content] pairs ready for Cross-Encoder scoring.
        """
        pairs = []
        truncated_indices = []  # which documents got truncated

        # Rough heuristic: leave ~70% of the token budget for the document and
        # subtract the query length. BUGFIX: clamp to a small positive floor so
        # a very long query cannot drive the budget negative, which previously
        # produced nonsensical negative slices in _truncate_text.
        max_doc_chars = max(int((self.max_length * 0.7) - len(query)), 50)

        for i, doc in enumerate(documents):
            content = doc.get("content", "")

            # Truncate over-long content.
            if len(content) > max_doc_chars:
                content = self._truncate_text(content, max_doc_chars)
                truncated_indices.append(i)

            pairs.append([query, content])

        if truncated_indices:
            logger.warning(
                f"⚠️ 有 {len(truncated_indices)} 個文檔因過長被截斷 "
                f"(最大長度: {max_doc_chars} 字符)"
            )

        return pairs

    def rerank(
        self,
        query: str,
        documents: List[Dict],
        top_k: int = 5,
        preserve_original_scores: bool = True
    ) -> List[Dict]:
        """
        Rerank candidate documents with the Cross-Encoder.

        Args:
            query: Query text.
            documents: Documents, each with "content" and optionally
                "hybrid_score". Neither the list nor its dicts are mutated.
            top_k: Number of results to return.
            preserve_original_scores: Keep the pre-rerank score for reference.

        Returns:
            New list of document dicts sorted by "rerank_score" (descending).
            Falls back to the first top_k inputs when scoring fails.
        """
        if not documents:
            logger.warning("⚠️ 文檔列表為空,返回空結果")
            return []

        if not query or not query.strip():
            logger.warning("⚠️ 查詢為空,返回原始文檔順序")
            return documents[:top_k]

        start_time = time.time()
        logger.info(f"🔄 開始重排 {len(documents)} 個文檔...")

        try:
            # 1. Build scoring pairs.
            pairs = self._prepare_pairs(query, documents)

            # 2. Score in batches to bound memory usage.
            scores = []
            for i in range(0, len(pairs), self.batch_size):
                batch_pairs = pairs[i:i + self.batch_size]
                batch_scores = self.model.predict(batch_pairs)
                scores.extend(batch_scores.tolist() if hasattr(batch_scores, 'tolist') else batch_scores)

            # 3. Attach scores to copies. BUGFIX: the original wrote the copies
            # back into the caller's list (documents[i] = doc), mutating it
            # despite the copy; build a fresh list instead.
            scored_docs = []
            for i, doc in enumerate(documents):
                doc = dict(doc)
                doc["rerank_score"] = float(scores[i])

                # Keep the pre-rerank score for reference.
                if preserve_original_scores and "hybrid_score" not in doc:
                    # No hybrid_score present: fall back to a generic score.
                    doc["original_score"] = doc.get("score", 0.0)

                scored_docs.append(doc)

            # 4. Sort by rerank score, best first.
            reranked_docs = sorted(
                scored_docs,
                key=lambda x: x.get("rerank_score", float('-inf')),
                reverse=True
            )

            # 5. Log summary statistics.
            elapsed_time = time.time() - start_time
            avg_score = sum(scores) / len(scores) if scores else 0.0
            max_score = max(scores) if scores else 0.0
            min_score = min(scores) if scores else 0.0

            logger.info(
                f"✅ 重排完成 (耗時: {elapsed_time:.2f}s, "
                f"平均分數: {avg_score:.4f}, "
                f"範圍: [{min_score:.4f}, {max_score:.4f}])"
            )

            return reranked_docs[:top_k]

        except Exception as e:
            logger.error(f"❌ 重排過程出錯: {e}")
            # Degradation: return the first top_k candidates in original order.
            logger.warning("⚠️ 使用降級策略:返回原始順序")
            return documents[:top_k]
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class RAGPipeline:
    """Coordinates the full RAG flow: hybrid recall followed by reranking."""

    def __init__(
        self,
        hybrid_search,
        reranker,
        recall_k: int = 25,
        adaptive_recall: bool = True,
        min_recall_k: int = 10,
        max_recall_k: int = 50
    ):
        """
        Initialize the RAG pipeline.

        Args:
            hybrid_search: HybridSearch instance (first-stage recall).
            reranker: Reranker instance (second-stage precision).
            recall_k: Default number of candidates recalled in stage one.
            adaptive_recall: Whether to scale recall_k per query.
            min_recall_k: Lower bound for adaptive recall_k.
            max_recall_k: Upper bound for adaptive recall_k.
        """
        self.hybrid_search = hybrid_search
        self.reranker = reranker
        self.base_recall_k = recall_k
        self.adaptive_recall = adaptive_recall
        self.min_recall_k = min_recall_k
        self.max_recall_k = max_recall_k

        # Running performance statistics (averages over all queries).
        self.stats = {
            "total_queries": 0,
            "avg_recall_time": 0.0,
            "avg_rerank_time": 0.0,
            "avg_total_time": 0.0
        }

    def _calculate_adaptive_recall_k(self, query: str) -> int:
        """
        Heuristically scale recall_k with query complexity.

        Args:
            query: Query text.

        Returns:
            Adjusted recall_k within [min_recall_k, max_recall_k].
        """
        if not self.adaptive_recall:
            return self.base_recall_k

        # Simple heuristic based on length and distinct-word count.
        query_length = len(query.split())
        keyword_count = len(set(query.lower().split()))

        # Complex queries need more candidates; very short ones need fewer.
        if query_length > 10 or keyword_count > 5:
            recall_k = min(self.base_recall_k * 2, self.max_recall_k)
        elif query_length < 3:
            recall_k = max(self.base_recall_k // 2, self.min_recall_k)
        else:
            recall_k = self.base_recall_k

        return recall_k

    def query(
        self,
        text: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        enable_rerank: bool = True,
        return_stats: bool = False
    ) -> List[Dict]:
        """
        Run the full search flow (recall + optional rerank).

        Args:
            text: Query text.
            top_k: Number of final results.
            metadata_filter: Optional metadata filter.
            enable_rerank: Whether to run the rerank stage (e.g. off for
                performance testing).
            return_stats: Also return per-query timing statistics.

        Returns:
            List of documents; when return_stats is True, a (results, stats)
            tuple. BUGFIX: empty-query and error paths now also honor
            return_stats (with an empty stats dict) instead of returning a
            bare list that would break tuple-unpacking callers.
        """
        if not text or not text.strip():
            logger.warning("⚠️ 查詢為空")
            return ([], {}) if return_stats else []

        total_start = time.time()
        self.stats["total_queries"] += 1

        # Dynamically size the recall stage.
        recall_k = self._calculate_adaptive_recall_k(text)
        logger.info(
            f"🔍 搜尋中: '{text[:50]}...' "
            f"(召回階段: {recall_k} 筆, 最終返回: {top_k} 筆)"
        )

        try:
            # Stage 1: hybrid recall.
            recall_start = time.time()
            initial_results = self.hybrid_search.retrieve(
                query=text,
                top_k=recall_k,
                metadata_filter=metadata_filter
            )
            recall_time = time.time() - recall_start

            if not initial_results:
                logger.warning("⚠️ 召回階段未找到任何結果")
                return ([], {}) if return_stats else []

            logger.info(
                f"✅ 召回階段完成: 找到 {len(initial_results)} 個候選 "
                f"(耗時: {recall_time:.2f}s)"
            )

            # Stage 2: rerank (skipped when disabled or too few candidates).
            if enable_rerank and len(initial_results) > top_k:
                rerank_start = time.time()
                final_results = self.reranker.rerank(
                    query=text,
                    documents=initial_results,
                    top_k=top_k
                )
                rerank_time = time.time() - rerank_start

                logger.info(
                    f"✅ 重排階段完成: 從 {len(initial_results)} 個候選中選出 "
                    f"{len(final_results)} 個結果 (耗時: {rerank_time:.2f}s)"
                )
            else:
                final_results = initial_results[:top_k]
                rerank_time = 0.0
                logger.info("⏭️ 跳過重排序階段(候選數不足或已禁用)")

            # Update running statistics.
            total_time = time.time() - total_start
            self._update_stats(recall_time, rerank_time, total_time)

            if return_stats:
                stats = {
                    "recall_time": recall_time,
                    "rerank_time": rerank_time,
                    "total_time": total_time,
                    "recall_k": recall_k,
                    "candidates_found": len(initial_results),
                    "final_results": len(final_results)
                }
                return final_results, stats

            return final_results

        except Exception as e:
            logger.error(f"❌ 查詢過程出錯: {e}")
            # Degradation: try recall-only.
            try:
                logger.warning("⚠️ 嘗試降級策略:僅使用召回結果")
                fallback = self.hybrid_search.retrieve(text, top_k=top_k, metadata_filter=metadata_filter)
                return (fallback, {}) if return_stats else fallback
            except Exception as e2:
                logger.error(f"❌ 降級策略也失敗: {e2}")
                return ([], {}) if return_stats else []

    def _update_stats(self, recall_time: float, rerank_time: float, total_time: float):
        """Fold one query's timings into the running averages."""
        n = self.stats["total_queries"]
        self.stats["avg_recall_time"] = (
            (self.stats["avg_recall_time"] * (n - 1) + recall_time) / n
        )
        self.stats["avg_rerank_time"] = (
            (self.stats["avg_rerank_time"] * (n - 1) + rerank_time) / n
        )
        self.stats["avg_total_time"] = (
            (self.stats["avg_total_time"] * (n - 1) + total_time) / n
        )

    def get_stats(self) -> Dict:
        """Return a copy of the performance statistics."""
        return self.stats.copy()

    def reset_stats(self):
        """Reset the performance statistics."""
        self.stats = {
            "total_queries": 0,
            "avg_recall_time": 0.0,
            "avg_rerank_time": 0.0,
            "avg_total_time": 0.0
        }

    def format_results_for_llm(
        self,
        results: List[Dict],
        format_style: str = "detailed"
    ) -> str:
        """
        Format retrieval results for LLM consumption (uses PromptFormatter
        when importable, otherwise a simple built-in layout).

        Args:
            results: Retrieval results.
            format_style: Formatting style ("detailed", "simple", "minimal").

        Returns:
            Formatted context string.
        """
        try:
            from ..prompt_formatter import PromptFormatter
            formatter = PromptFormatter(format_style=format_style)
            return formatter.format_context(results)
        except ImportError:
            # Fallback: simple built-in formatting.
            formatted_parts = []
            for i, result in enumerate(results, 1):
                metadata = result.get("metadata", {})
                content = result.get("content", "")
                arxiv_id = metadata.get('arxiv_id', 'N/A')
                title = metadata.get('title', 'N/A')
                formatted_parts.append(
                    f"[來源 {i}: {title} (arXiv:{arxiv_id})]\n{content}\n"
                )
            # BUGFIX: the newline previously bound to str.join instead of the
            # "=" bar, gluing the separator onto the first source line.
            return "\n" + "=" * 60 + "\n" + "\n".join(formatted_parts)
|
| 448 |
+
|
src/retrievers/vector_retriever.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
向量檢索器模組:使用 embedding 和向量資料庫進行語義檢索
|
| 3 |
+
|
| 4 |
+
支援兩種初始化方式:
|
| 5 |
+
1. 自動初始化 embeddings(預設):根據參數創建新的 embedding 模型
|
| 6 |
+
2. 使用外部 embeddings:接收已初始化的 embedding 模型(可與 DocumentProcessor 共用)
|
| 7 |
+
"""
|
| 8 |
+
from typing import List, Dict, Optional, Any
|
| 9 |
+
from langchain_community.vectorstores import Chroma
|
| 10 |
+
from langchain_core.documents import Document
|
| 11 |
+
import os
|
| 12 |
+
from .base import BaseRetriever
|
| 13 |
+
|
| 14 |
+
# 嘗試導入 HuggingFaceEmbeddings(免費模型)
|
| 15 |
+
try:
|
| 16 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 17 |
+
except ImportError:
|
| 18 |
+
try:
|
| 19 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 20 |
+
except ImportError:
|
| 21 |
+
raise ImportError("需要安裝 langchain-community 或 langchain-huggingface 才能使用 Hugging Face embeddings")
|
| 22 |
+
|
| 23 |
+
# 導入 torch 來檢測可用的設備
|
| 24 |
+
try:
|
| 25 |
+
import torch
|
| 26 |
+
TORCH_AVAILABLE = True
|
| 27 |
+
except ImportError:
|
| 28 |
+
TORCH_AVAILABLE = False
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_device() -> str:
    """
    Detect and return the best available compute device.

    Returns:
        'mps' (macOS GPU), 'cuda' (NVIDIA GPU), or 'cpu'.
    """
    if not TORCH_AVAILABLE:
        return 'cpu'

    # Preference order: MPS (macOS) > CUDA (NVIDIA) > CPU.
    if torch.backends.mps.is_available():
        return 'mps'
    if torch.cuda.is_available():
        return 'cuda'
    return 'cpu'
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class VectorRetriever(BaseRetriever):
    """Semantic search over documents using embeddings and a Chroma vector store."""

    def __init__(
        self,
        documents: List[Dict],
        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        persist_directory: Optional[str] = "./chroma_db",
        hf_cache_dir: Optional[str] = None,
        device: Optional[str] = None,
        embeddings: Optional[Any] = None  # optional externally supplied embedding model (takes precedence)
    ):
        """
        Initialize the vector retriever (Hugging Face embeddings).

        Args:
            documents: Documents, each with "content" and "metadata".
            embedding_model: Hugging Face embedding model name
                (default "sentence-transformers/all-MiniLM-L6-v2").
                Only used when embeddings is None.
            persist_directory: Chroma persistence directory.
            hf_cache_dir: Hugging Face model cache directory (e.g. an external
                drive). Falls back to the HF_HOME env var or
                ~/.cache/huggingface/. Only used when embeddings is None.
            device: Device name ('mps', 'cuda', 'cpu'); auto-detected when
                None. Only used when embeddings is None.
            embeddings: Optional pre-built embedding model object. When given,
                it is used directly and embedding_model / hf_cache_dir /
                device are ignored. This allows sharing one model instance
                with DocumentProcessor: saves memory (one load), saves time
                (no re-initialization), and keeps chunking and retrieval
                consistent.
        """
        # Prefer an externally shared model when one was provided.
        if embeddings is not None:
            self.embeddings = embeddings
            print("✓ 使用外部傳入的 embeddings 模型(與 DocumentProcessor 共用)")
        else:
            # Otherwise build our own embedding model.
            print(f"使用 Hugging Face embedding 模型: {embedding_model}")

            # Point Hugging Face at the requested cache directory.
            # NOTE(review): this mutates process-wide environment variables as
            # a side effect — confirm that is acceptable for all callers.
            if hf_cache_dir:
                os.environ['HF_HOME'] = hf_cache_dir
                os.environ['TRANSFORMERS_CACHE'] = hf_cache_dir
                print(f"模型將存儲在: {hf_cache_dir}")
            else:
                # Report where models will be cached by default.
                default_cache = os.path.expanduser("~/.cache/huggingface")
                current_cache = os.getenv('HF_HOME', default_cache)
                print(f"模型緩存位置: {current_cache}")
                print("提示: 可以通過設置 hf_cache_dir 參數或環境變數 HF_HOME 來指定外接硬碟路徑")

            # Auto-detect the device unless one was specified.
            if device is None:
                device = get_device()

            device_name_map = {
                'mps': 'MPS (macOS GPU)',
                'cuda': 'CUDA (NVIDIA GPU)',
                'cpu': 'CPU'
            }
            print(f"使用設備: {device_name_map.get(device, device)}")
            print("首次使用時會下載模型,請稍候...")

            # Build model_kwargs with the device and optional cache dir.
            # NOTE(review): 'cache_dir' is forwarded verbatim into the
            # underlying model constructor — verify the installed
            # HuggingFaceEmbeddings version accepts that key.
            model_kwargs = {'device': device}
            if hf_cache_dir:
                model_kwargs['cache_dir'] = hf_cache_dir

            self.embeddings = HuggingFaceEmbeddings(
                model_name=embedding_model,
                model_kwargs=model_kwargs,
                encode_kwargs={'normalize_embeddings': True}  # normalize embeddings for better similarity behavior
            )

        # Convert documents to LangChain Document objects. Lists in metadata
        # must be flattened to strings because ChromaDB rejects list values.
        def sanitize_metadata(metadata: Dict) -> Dict:
            """Coerce list/dict/set metadata values to strings for ChromaDB."""
            sanitized = {}
            for key, value in metadata.items():
                if isinstance(value, list):
                    # Lists become comma-separated strings.
                    sanitized[key] = ", ".join(str(v) for v in value)
                elif isinstance(value, (dict, set)):
                    # Dicts/sets become their str() representation.
                    sanitized[key] = str(value)
                else:
                    # str / int / float / bool / None pass through unchanged.
                    sanitized[key] = value
            return sanitized

        langchain_docs = [
            Document(
                page_content=doc["content"],
                metadata=sanitize_metadata(doc["metadata"])
            )
            for doc in documents
        ]

        # Create (or persist into) the vector store.
        self.vectorstore = Chroma.from_documents(
            documents=langchain_docs,
            embedding=self.embeddings,
            persist_directory=persist_directory
        )

        # Also expose a standard LangChain retriever interface.
        self.retriever = self.vectorstore.as_retriever()

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Retrieve relevant documents with a normalized similarity score
        (higher is better). Supports optional metadata filtering.

        Args:
            query: Query text.
            top_k: Number of results to return.
            metadata_filter: Optional metadata filter dict, e.g.
                {"arxiv_id": "1234.5678"} to restrict to one paper's chunks,
                or {"title": "Machine Learning"} for a specific title.
                Multiple conditions must all hold (AND logic). String values
                use case-insensitive substring matching; a {"$eq": ...} dict
                forces exact matching; other value types compare by equality.

        Returns:
            List of documents, each with "content", "metadata", and "score",
            filtered according to metadata_filter.
        """
        # When a metadata_filter is given, over-fetch and filter in Python:
        # LangChain/Chroma's similarity_search_with_score filter support
        # varies across versions, so this is the portable path.
        if metadata_filter:
            # Fetch extra candidates so enough survive the filter.
            results_with_scores = self.vectorstore.similarity_search_with_score(
                query,
                k=top_k * 10  # over-fetch before filtering
            )

            # Apply the filter in Python.
            filtered_results = []
            for doc, distance_score in results_with_scores:
                metadata = doc.metadata
                matches = True

                for key, value in metadata_filter.items():
                    doc_value = metadata.get(key)

                    # Evaluate one filter condition.
                    if isinstance(value, dict):
                        # Operator form, e.g. {"$eq": "value"}.
                        if "$eq" in value:
                            if doc_value != value["$eq"]:
                                matches = False
                                break
                        else:
                            # Unknown operators reject the document
                            # (extend here for more operators).
                            matches = False
                            break
                    elif isinstance(value, str) and isinstance(doc_value, str):
                        # Strings: case-insensitive substring match.
                        if value.lower() not in doc_value.lower():
                            matches = False
                            break
                    else:
                        # Everything else: exact equality.
                        if doc_value != value:
                            matches = False
                            break

                if matches:
                    filtered_results.append((doc, distance_score))

            # Keep only the first top_k survivors.
            results_with_scores = filtered_results[:top_k]
        else:
            # No filter: fetch results directly.
            results_with_scores = self.vectorstore.similarity_search_with_score(
                query,
                k=top_k
            )

        # Build the result list, converting distance to similarity.
        results = []
        for doc, distance_score in results_with_scores:
            # With normalized embeddings, squared L2 distance = 2 - 2*cos_sim,
            # so cos_sim = 1 - distance^2 / 2, giving a [0, 1] score where
            # higher means more similar.
            # NOTE(review): this assumes the store returns the raw (unsquared)
            # L2 distance — verify against Chroma's configured distance metric.
            similarity_score = 1 - (distance_score**2 / 2)

            results.append({
                "content": doc.page_content,
                "metadata": doc.metadata,
                "score": float(similarity_score),
            })

        return results
|
| 254 |
+
|
src/step_back_rag.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Step-back Prompting 雙軌 RAG:結合具體事實與抽象原理
|
| 3 |
+
使用 Step-back Prompting 技術,同時檢索具體事實和抽象原理,提升回答質量
|
| 4 |
+
"""
|
| 5 |
+
from typing import List, Dict, Optional
|
| 6 |
+
from .retrievers.reranker import RAGPipeline
|
| 7 |
+
from .retrievers.vector_retriever import VectorRetriever
|
| 8 |
+
from .prompt_formatter import PromptFormatter
|
| 9 |
+
from .llm_integration import OllamaLLM
|
| 10 |
+
import time
|
| 11 |
+
import logging
|
| 12 |
+
import hashlib
|
| 13 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class StepBackRAG:
|
| 19 |
+
"""使用 Step-back Prompting 的雙軌 RAG 系統"""
|
| 20 |
+
|
| 21 |
+
    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        vector_retriever: VectorRetriever,
        llm: OllamaLLM,
        step_back_temperature: float = 0.3,  # lower temperature for generating the abstract question
        answer_temperature: float = 0.7,
        enable_parallel: bool = True
    ):
        """
        Initialize the Step-back RAG system.

        Args:
            rag_pipeline: RAG pipeline instance (used for final answer generation).
            vector_retriever: Vector retriever used by both retrieval tracks.
            llm: LLM instance.
            step_back_temperature: Temperature for generating the abstract
                question (lower = more stable output).
            answer_temperature: Temperature for answer generation.
            enable_parallel: Whether to run the two retrieval tracks in parallel.
        """
        self.rag_pipeline = rag_pipeline
        self.vector_retriever = vector_retriever
        self.llm = llm
        self.step_back_temperature = step_back_temperature
        self.answer_temperature = answer_temperature
        self.enable_parallel = enable_parallel
|
| 47 |
+
|
| 48 |
+
    def _generate_step_back_question(self, question: str) -> str:
        """
        Generate the step-back (abstract) question for a concrete one.

        The prompt language follows the detected language of the input.

        Args:
            question: The original, specific question.

        Returns:
            An abstract question about underlying principles; falls back to
            the original question when generation fails or returns empty text.
        """
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            prompt = f"""你是一個資深專家。請將以下具體問題轉換為一個更抽象、更基礎的原理性問題。
這個抽象問題應該幫助理解該領域的基礎概念和原理,而不是直接回答具體問題。

具體問題: {question}

請生成一個抽象問題,用於檢索相關的原理和背景知識:
"""
        else:
            prompt = f"""You are a senior expert. Please convert the following specific question into a more abstract, fundamental question about principles and concepts.
This abstract question should help understand the basic concepts and principles in this field, rather than directly answering the specific question.

Specific question: {question}

Please generate an abstract question for retrieving relevant principles and background knowledge:
"""

        try:
            abstract_question = self.llm.generate(
                prompt=prompt,
                temperature=self.step_back_temperature,
                max_tokens=200
            )

            abstract_question = abstract_question.strip()

            # Guard: an empty generation falls back to the original question.
            if not abstract_question:
                logger.warning("⚠️ 生成的抽象問題為空,使用原始問題")
                return question

            logger.info(f"✅ 生成抽象問題: '{abstract_question}'")
            return abstract_question

        except Exception as e:
            # Any LLM failure degrades gracefully to the original question.
            logger.error(f"⚠️ 生成抽象問題時出錯: {e}")
            return question
|
| 96 |
+
|
| 97 |
+
def _get_doc_id(self, doc: Dict) -> str:
|
| 98 |
+
"""
|
| 99 |
+
生成文檔的唯一標識符
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
doc: 文檔字典
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
唯一 ID
|
| 106 |
+
"""
|
| 107 |
+
metadata = doc.get("metadata", {})
|
| 108 |
+
content = doc.get("content", "")
|
| 109 |
+
|
| 110 |
+
if "arxiv_id" in metadata and "chunk_index" in metadata:
|
| 111 |
+
return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
|
| 112 |
+
elif "file_path" in metadata and "chunk_index" in metadata:
|
| 113 |
+
return f"{metadata['file_path']}_{metadata['chunk_index']}"
|
| 114 |
+
else:
|
| 115 |
+
content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
|
| 116 |
+
return f"doc_{content_hash}"
|
| 117 |
+
|
| 118 |
+
def _retrieve_direct(self, question: str, top_k: int, metadata_filter: Optional[Dict] = None) -> List[Dict]:
|
| 119 |
+
"""直接檢索原始問題(具體事實)"""
|
| 120 |
+
return self.vector_retriever.retrieve(
|
| 121 |
+
query=question,
|
| 122 |
+
top_k=top_k,
|
| 123 |
+
metadata_filter=metadata_filter
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def _retrieve_step_back(self, question: str, top_k: int, metadata_filter: Optional[Dict] = None) -> tuple:
|
| 127 |
+
"""Step-back 檢索(抽象原理)"""
|
| 128 |
+
abstract_question = self._generate_step_back_question(question)
|
| 129 |
+
results = self.vector_retriever.retrieve(
|
| 130 |
+
query=abstract_question,
|
| 131 |
+
top_k=top_k,
|
| 132 |
+
metadata_filter=metadata_filter
|
| 133 |
+
)
|
| 134 |
+
return results, abstract_question
|
| 135 |
+
|
| 136 |
+
def query(
|
| 137 |
+
self,
|
| 138 |
+
question: str,
|
| 139 |
+
top_k: int = 5,
|
| 140 |
+
metadata_filter: Optional[Dict] = None,
|
| 141 |
+
return_abstract_question: bool = False
|
| 142 |
+
) -> Dict:
|
| 143 |
+
"""
|
| 144 |
+
執行雙軌檢索(不生成答案)
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
question: 原始問題
|
| 148 |
+
top_k: 每軌返回的結果數量
|
| 149 |
+
metadata_filter: 可選的 metadata 過濾條件
|
| 150 |
+
return_abstract_question: 是否返回抽象問題
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
包含雙軌檢索結果的字典
|
| 154 |
+
"""
|
| 155 |
+
start_time = time.time()
|
| 156 |
+
|
| 157 |
+
if self.enable_parallel:
|
| 158 |
+
# 並行執行雙軌檢索
|
| 159 |
+
logger.info(f"🔄 並行執行雙軌檢索: '{question}'")
|
| 160 |
+
with ThreadPoolExecutor(max_workers=2) as executor:
|
| 161 |
+
direct_future = executor.submit(
|
| 162 |
+
self._retrieve_direct, question, top_k, metadata_filter
|
| 163 |
+
)
|
| 164 |
+
step_back_future = executor.submit(
|
| 165 |
+
self._retrieve_step_back, question, top_k, metadata_filter
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
specific_results = direct_future.result()
|
| 169 |
+
abstract_results, abstract_question = step_back_future.result()
|
| 170 |
+
else:
|
| 171 |
+
# 串行執行
|
| 172 |
+
logger.info(f"🔄 串行執行雙軌檢索: '{question}'")
|
| 173 |
+
specific_results = self._retrieve_direct(question, top_k, metadata_filter)
|
| 174 |
+
abstract_results, abstract_question = self._retrieve_step_back(question, top_k, metadata_filter)
|
| 175 |
+
|
| 176 |
+
elapsed_time = time.time() - start_time
|
| 177 |
+
logger.info(
|
| 178 |
+
f"✅ 雙軌檢索完成(耗時: {elapsed_time:.2f}s)\n"
|
| 179 |
+
f" 具體事實: {len(specific_results)} 個結果\n"
|
| 180 |
+
f" 抽象原理: {len(abstract_results)} 個結果"
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
return {
|
| 184 |
+
"specific_context": specific_results,
|
| 185 |
+
"abstract_context": abstract_results,
|
| 186 |
+
"abstract_question": abstract_question if return_abstract_question else None,
|
| 187 |
+
"question": question,
|
| 188 |
+
"elapsed_time": elapsed_time
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
    def generate_answer(
        self,
        question: str,
        formatter: PromptFormatter,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        document_type: str = "general",
        return_abstract_question: bool = False
    ) -> Dict:
        """
        Full Step-back RAG flow: dual-track retrieval -> fused answer generation.

        Args:
            question: Original question.
            formatter: Prompt formatter used to render retrieved contexts.
            top_k: Number of documents per track used for answer generation.
            metadata_filter: Optional metadata filter condition.
            document_type: Document type ("paper", "cv", "general").
            return_abstract_question: Whether to include the abstract question
                in the returned dict.

        Returns:
            Dict with the retrieval results, the generated answer, the
            formatted contexts, and timing statistics.
        """
        start_time = time.time()

        # Step 1: dual-track retrieval (concrete facts + abstract principles).
        retrieval_result = self.query(
            question=question,
            top_k=top_k,
            metadata_filter=metadata_filter,
            return_abstract_question=return_abstract_question
        )

        specific_results = retrieval_result["specific_context"]
        abstract_results = retrieval_result["abstract_context"]

        # Nothing retrieved on either track: return a canned apology without
        # spending an LLM call.
        if not specific_results and not abstract_results:
            return {
                **retrieval_result,
                "answer": "抱歉,未找到相關文檔來回答此問題。",
                "formatted_context": None,
                "answer_time": 0.0,
                "total_time": retrieval_result["elapsed_time"]
            }

        # Step 2: format the two contexts independently; a placeholder string
        # stands in for whichever track came back empty.
        specific_context = formatter.format_context(
            specific_results,
            document_type=document_type
        ) if specific_results else "未找到相關的具體事實資料。"

        abstract_context = formatter.format_context(
            abstract_results,
            document_type=document_type
        ) if abstract_results else "未找到相關的基礎原理資料。"

        # Step 3: build the fused prompt (the key step) — principles first,
        # then facts, then the user question.
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            final_prompt = f"""你是一個資深專家。請結合以下兩類資訊來回答使用者的具體問題。

【基礎原理與背景】
{abstract_context}

【具體事實資料】
{specific_context}

使用者問題:{question}

請根據原理推導並結合事實,給出一個專業且具備邏輯的回答:
"""
        else:
            final_prompt = f"""You are a senior expert. Please answer the user's specific question by combining the following two types of information.

【Fundamental Principles and Background】
{abstract_context}

【Specific Facts and Data】
{specific_context}

User question: {question}

Please provide a professional and logical answer based on principles and facts:
"""

        # Step 4: generate the answer. LLM failures are reported in the answer
        # field rather than raised, so callers always get a uniform dict.
        logger.info("🤖 生成回答中...")
        answer_start = time.time()
        try:
            answer = self.llm.generate(
                prompt=final_prompt,
                temperature=self.answer_temperature,
                max_tokens=2048
            )
            answer_time = time.time() - answer_start
            logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
        except Exception as e:
            logger.error(f"❌ 生成回答時出錯: {e}")
            answer = f"生成回答時出錯: {e}"
            answer_time = time.time() - answer_start

        total_time = time.time() - start_time

        return {
            **retrieval_result,
            "answer": answer,
            "formatted_context": {
                "specific": specific_context,
                "abstract": abstract_context
            },
            "answer_time": answer_time,
            "total_time": total_time
        }
|
| 305 |
+
|
src/subquery_rag.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sub-query Decomposition RAG:將複雜問題拆解成子問題後檢索
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Dict, Optional
|
| 5 |
+
from .retrievers.reranker import RAGPipeline
|
| 6 |
+
from .prompt_formatter import PromptFormatter
|
| 7 |
+
from .llm_integration import OllamaLLM
|
| 8 |
+
import hashlib
|
| 9 |
+
import time
|
| 10 |
+
import logging
|
| 11 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SubQueryDecompositionRAG:
    """RAG system that decomposes a question into sub-questions.

    A complex question is split into focused sub-questions, documents are
    retrieved for each, the union is deduplicated (keeping the highest-scored
    copy of each document), and the top results feed answer generation.
    """

    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        llm: OllamaLLM,
        max_sub_queries: int = 3,
        top_k_per_subquery: int = 5,
        enable_parallel: bool = True
    ):
        """
        Initialize the Sub-query Decomposition RAG.

        Args:
            rag_pipeline: Existing RAG pipeline instance used for retrieval.
            llm: LLM instance used to generate the sub-questions.
            max_sub_queries: Maximum number of sub-questions to generate.
            top_k_per_subquery: Number of results retrieved per sub-question.
            enable_parallel: Whether to process sub-queries in parallel.
        """
        self.rag_pipeline = rag_pipeline
        self.llm = llm
        self.max_sub_queries = max_sub_queries
        self.top_k_per_subquery = top_k_per_subquery
        self.enable_parallel = enable_parallel

    def _generate_sub_queries(self, question: str) -> List[str]:
        """
        Decompose the original question into sub-questions.

        Args:
            question: The original question.

        Returns:
            List of sub-questions; falls back to ``[question]`` when the LLM
            fails or produces nothing usable.
        """
        # Detect language to pick the matching prompt template.
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。

原始問題: {question}

子問題清單:"""
        else:
            prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.

Original question: {question}

Sub-question list:"""

        try:
            response = self.llm.generate(
                prompt=prompt,
                temperature=0.3,  # low temperature for more stable decomposition
                max_tokens=500
            )

            # Parse sub-questions: one per line, skipping comment-like lines.
            sub_queries = [
                q.strip()
                for q in response.strip().split("\n")
                if q.strip() and not q.strip().startswith("#")
            ]

            # Strip numbering prefixes (e.g. "1. ", "1) ").
            # NOTE(review): lstrip removes any leading run of these characters,
            # so a question that genuinely starts with a digit loses it.
            cleaned_queries = []
            for q in sub_queries:
                # Remove leading numbering characters.
                q = q.lstrip("0123456789. )")
                q = q.strip()
                if q:
                    cleaned_queries.append(q)

            # Cap the number of sub-questions.
            cleaned_queries = cleaned_queries[:self.max_sub_queries]

            # If nothing usable was produced, fall back to the original question.
            if not cleaned_queries:
                logger.warning("⚠️ 未生成子問題,使用原始問題")
                cleaned_queries = [question]

            return cleaned_queries

        except Exception as e:
            logger.error(f"⚠️ 生成子問題時出錯: {e}")
            # Fall back to the original question on any LLM failure.
            return [question]

    def _get_doc_id(self, doc: Dict) -> str:
        """
        Build a stable unique identifier for a document chunk.

        Args:
            doc: Document dict with optional ``metadata``/``content``.

        Returns:
            A unique ID string.
        """
        metadata = doc.get("metadata", {})
        content = doc.get("content", "")

        # Use an explicit identifier from metadata when available.
        if "arxiv_id" in metadata and "chunk_index" in metadata:
            return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
        elif "file_path" in metadata and "chunk_index" in metadata:
            return f"{metadata['file_path']}_{metadata['chunk_index']}"
        else:
            # Fall back to a hash of the chunk content.
            content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
            return f"doc_{content_hash}"

    def _retrieve_for_subquery(
        self,
        sub_query: str,
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Retrieve documents for a single sub-question.

        Args:
            sub_query: The sub-question.
            metadata_filter: Optional metadata filter condition.

        Returns:
            Retrieved documents; empty list if retrieval raised.
        """
        try:
            results = self.rag_pipeline.query(
                text=sub_query,
                top_k=self.top_k_per_subquery,
                metadata_filter=metadata_filter,
                enable_rerank=True
            )
            return results
        except Exception as e:
            logger.error(f"⚠️ 檢索子問題 '{sub_query}' 時出錯: {e}")
            return []

    def _get_unique_documents(
        self,
        sub_queries: List[str],
        metadata_filter: Optional[Dict] = None
    ) -> List[Dict]:
        """
        Retrieve for all sub-questions and deduplicate the results.

        When a document appears for multiple sub-questions, the copy with the
        higher score (``rerank_score``, falling back to ``hybrid_score``) wins.

        Args:
            sub_queries: List of sub-questions.
            metadata_filter: Optional metadata filter condition.

        Returns:
            Deduplicated documents sorted by score, descending.
        """
        unique_docs = {}

        if self.enable_parallel and len(sub_queries) > 1:
            # Parallel path: fan sub-queries out across worker threads.
            logger.info(f"🔄 並行處理 {len(sub_queries)} 個子查詢...")
            with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
                future_to_query = {
                    executor.submit(self._retrieve_for_subquery, q, metadata_filter): q
                    for q in sub_queries
                }

                for future in as_completed(future_to_query):
                    sub_query = future_to_query[future]
                    try:
                        docs = future.result()
                        logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(docs)} 個結果")
                        for doc in docs:
                            doc_id = self._get_doc_id(doc)
                            if doc_id not in unique_docs:
                                unique_docs[doc_id] = doc
                            else:
                                # Duplicate: keep whichever copy scored higher.
                                existing_score = unique_docs[doc_id].get(
                                    'rerank_score',
                                    unique_docs[doc_id].get('hybrid_score', 0)
                                )
                                new_score = doc.get(
                                    'rerank_score',
                                    doc.get('hybrid_score', 0)
                                )
                                if new_score > existing_score:
                                    unique_docs[doc_id] = doc
                    except Exception as e:
                        logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
        else:
            # Serial path: same dedup logic, one sub-query at a time.
            logger.info(f"🔄 串行處理 {len(sub_queries)} 個子查詢...")
            for sub_query in sub_queries:
                docs = self._retrieve_for_subquery(sub_query, metadata_filter)
                logger.debug(f"✅ 子問題 '{sub_query}' 找到 {len(docs)} 個結果")
                for doc in docs:
                    doc_id = self._get_doc_id(doc)
                    if doc_id not in unique_docs:
                        unique_docs[doc_id] = doc
                    else:
                        # Keep the higher-scored copy.
                        existing_score = unique_docs[doc_id].get(
                            'rerank_score',
                            unique_docs[doc_id].get('hybrid_score', 0)
                        )
                        new_score = doc.get(
                            'rerank_score',
                            doc.get('hybrid_score', 0)
                        )
                        if new_score > existing_score:
                            unique_docs[doc_id] = doc

        # Sort by score, best first.
        result_list = list(unique_docs.values())
        result_list.sort(
            key=lambda x: x.get('rerank_score', x.get('hybrid_score', 0)),
            reverse=True
        )

        return result_list

    def query(
        self,
        question: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        return_sub_queries: bool = False
    ) -> Dict:
        """
        Run the Sub-query Decomposition RAG retrieval (no answer generation).

        Args:
            question: Original question.
            top_k: Number of top results to return.
            metadata_filter: Optional metadata filter condition.
            return_sub_queries: Whether to include the sub-question list in
                the result.

        Returns:
            Dict with the retrieval results and timing statistics.
        """
        start_time = time.time()

        # Step 1: generate the sub-questions.
        logger.info(f"🔍 拆解問題: '{question}'")
        sub_queries = self._generate_sub_queries(question)
        logger.info(f"✅ 生成 {len(sub_queries)} 個子問題:")
        for i, sq in enumerate(sub_queries, 1):
            logger.info(f" {i}. {sq}")

        # Step 2: retrieve and deduplicate.
        logger.info(f"📚 檢索相關文檔...")
        docs = self._get_unique_documents(sub_queries, metadata_filter)
        logger.info(f"✅ 找到 {len(docs)} 個唯一文檔(去重後)")

        # Step 3: return the top_k results.
        final_results = docs[:top_k]

        elapsed_time = time.time() - start_time

        result = {
            "results": final_results,
            "total_docs_found": len(docs),
            "sub_queries": sub_queries if return_sub_queries else None,
            "elapsed_time": elapsed_time
        }

        return result

    def generate_answer(
        self,
        question: str,
        formatter: PromptFormatter,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        document_type: str = "general",
        return_sub_queries: bool = False
    ) -> Dict:
        """
        Full Sub-query Decomposition RAG flow: retrieval + answer generation.

        Args:
            question: Original question.
            formatter: Prompt formatter.
            top_k: Number of top results used for answer generation.
            metadata_filter: Optional metadata filter condition.
            document_type: Document type ("paper", "cv", "general").
            return_sub_queries: Whether to include the sub-question list.

        Returns:
            Dict with retrieval results, the generated answer, and timing
            statistics. Note: the no-results early return omits the
            ``answer_time``/``total_time`` keys.
        """
        # Retrieval phase.
        retrieval_result = self.query(
            question=question,
            top_k=top_k,
            metadata_filter=metadata_filter,
            return_sub_queries=return_sub_queries
        )

        if not retrieval_result["results"]:
            return {
                **retrieval_result,
                "answer": "抱歉,未找到相關文檔來回答此問題。",
                "formatted_context": None
            }

        # Format the retrieved context.
        formatted_context = formatter.format_context(
            retrieval_result["results"],
            document_type=document_type
        )

        # Build the prompt.
        prompt = formatter.create_prompt(
            question,
            formatted_context,
            document_type=document_type
        )

        # Generate the answer; LLM failures are reported in the answer field.
        logger.info("🤖 生成回答中...")
        answer_start = time.time()
        try:
            answer = self.llm.generate(
                prompt=prompt,
                temperature=0.7,
                max_tokens=2048
            )
            answer_time = time.time() - answer_start
            logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
        except Exception as e:
            logger.error(f"❌ 生成回答時出錯: {e}")
            answer = f"生成回答時出錯: {e}"
            answer_time = time.time() - answer_start

        return {
            **retrieval_result,
            "answer": answer,
            "formatted_context": formatted_context,
            "answer_time": answer_time,
            "total_time": retrieval_result["elapsed_time"] + answer_time
        }
|
| 361 |
+
|
src/triple_hybrid_rag.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Triple Hybrid RAG:融合 SubQuery + HyDE + Step-back Prompting
|
| 3 |
+
結合三種技術的優勢,實現最強大的 RAG 系統
|
| 4 |
+
"""
|
| 5 |
+
from typing import List, Dict, Optional
|
| 6 |
+
from .retrievers.reranker import RAGPipeline
|
| 7 |
+
from .retrievers.vector_retriever import VectorRetriever
|
| 8 |
+
from .prompt_formatter import PromptFormatter
|
| 9 |
+
from .llm_integration import OllamaLLM
|
| 10 |
+
import hashlib
|
| 11 |
+
import time
|
| 12 |
+
import logging
|
| 13 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TripleHybridRAG:
|
| 19 |
+
"""融合 SubQuery + HyDE + Step-back 的三重混合 RAG 系統"""
|
| 20 |
+
|
| 21 |
+
    def __init__(
        self,
        rag_pipeline: RAGPipeline,
        vector_retriever: VectorRetriever,
        llm: OllamaLLM,
        max_sub_queries: int = 3,
        top_k_per_subquery: int = 5,
        hypothetical_length: int = 200,
        temperature_subquery: float = 0.3,
        temperature_hyde: float = 0.7,
        temperature_stepback: float = 0.3,
        answer_temperature: float = 0.7,
        enable_parallel: bool = True
    ):
        """
        Initialize the triple hybrid RAG (SubQuery + HyDE + Step-back).

        Args:
            rag_pipeline: RAG pipeline instance.
            vector_retriever: Vector retriever used for HyDE / step-back search.
            llm: LLM instance.
            max_sub_queries: Maximum number of sub-questions to generate.
            top_k_per_subquery: Number of results retrieved per sub-question.
            hypothetical_length: Target length of hypothetical documents
                (characters).
            temperature_subquery: Temperature for sub-question generation
                (lower = more stable).
            temperature_hyde: Temperature for hypothetical documents
                (higher = richer terminology).
            temperature_stepback: Temperature for abstract-question generation
                (lower = more stable).
            answer_temperature: Temperature for final answer generation.
            enable_parallel: Whether to process sub-queries in parallel.
        """
        self.rag_pipeline = rag_pipeline
        self.vector_retriever = vector_retriever
        self.llm = llm
        self.max_sub_queries = max_sub_queries
        self.top_k_per_subquery = top_k_per_subquery
        self.hypothetical_length = hypothetical_length
        self.temperature_subquery = temperature_subquery
        self.temperature_hyde = temperature_hyde
        self.temperature_stepback = temperature_stepback
        self.answer_temperature = answer_temperature
        self.enable_parallel = enable_parallel
|
| 62 |
+
|
| 63 |
+
def _generate_sub_queries(self, question: str) -> List[str]:
|
| 64 |
+
"""生成子問題(SubQuery)"""
|
| 65 |
+
is_chinese = PromptFormatter.detect_language(question) == "zh"
|
| 66 |
+
|
| 67 |
+
if is_chinese:
|
| 68 |
+
prompt = f"""你是一個專業助理。請將以下原始問題拆解成最多 {self.max_sub_queries} 個具體的子問題,以便進行資料搜尋。
|
| 69 |
+
每個子問題應專注於原始問題的一個特定面向。請以換行符號分隔問題。
|
| 70 |
+
|
| 71 |
+
原始問題: {question}
|
| 72 |
+
|
| 73 |
+
子問題清單:"""
|
| 74 |
+
else:
|
| 75 |
+
prompt = f"""You are a professional assistant. Please decompose the following original question into at most {self.max_sub_queries} specific sub-questions for information retrieval.
|
| 76 |
+
Each sub-question should focus on a specific aspect of the original question. Please separate questions with newlines.
|
| 77 |
+
|
| 78 |
+
Original question: {question}
|
| 79 |
+
|
| 80 |
+
Sub-question list:"""
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
response = self.llm.generate(
|
| 84 |
+
prompt=prompt,
|
| 85 |
+
temperature=self.temperature_subquery,
|
| 86 |
+
max_tokens=500
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
sub_queries = [
|
| 90 |
+
q.strip()
|
| 91 |
+
for q in response.strip().split("\n")
|
| 92 |
+
if q.strip() and not q.strip().startswith("#")
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
# 移除編號前綴
|
| 96 |
+
cleaned_queries = []
|
| 97 |
+
for q in sub_queries:
|
| 98 |
+
q = q.lstrip("0123456789. )")
|
| 99 |
+
q = q.strip()
|
| 100 |
+
if q:
|
| 101 |
+
cleaned_queries.append(q)
|
| 102 |
+
|
| 103 |
+
cleaned_queries = cleaned_queries[:self.max_sub_queries]
|
| 104 |
+
|
| 105 |
+
if not cleaned_queries:
|
| 106 |
+
logger.warning("⚠️ 未生成子問題,使用原始問題")
|
| 107 |
+
cleaned_queries = [question]
|
| 108 |
+
|
| 109 |
+
return cleaned_queries
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
logger.error(f"⚠️ 生成子問題時出錯: {e}")
|
| 113 |
+
return [question]
|
| 114 |
+
|
| 115 |
+
def _generate_hypothetical_document(self, sub_query: str) -> str:
|
| 116 |
+
"""為子問題生成假設性文檔(HyDE)"""
|
| 117 |
+
is_chinese = PromptFormatter.detect_language(sub_query) == "zh"
|
| 118 |
+
|
| 119 |
+
if is_chinese:
|
| 120 |
+
prompt = f"""請針對以下問題,寫出一段約 {self.hypothetical_length} 字的專業技術檔案內容。
|
| 121 |
+
這段內容應包含該領域常見的專業術語與原理說明,以便用於後續的語義檢索。
|
| 122 |
+
請使用專業的術語和概念,即使你對某些細節不確定,也要包含相關的專業詞彙。
|
| 123 |
+
|
| 124 |
+
問題: {sub_query}
|
| 125 |
+
|
| 126 |
+
專業��術內容:"""
|
| 127 |
+
else:
|
| 128 |
+
prompt = f"""Please write a professional technical document of approximately {self.hypothetical_length} words in response to the following question.
|
| 129 |
+
This content should include common professional terminology and principle explanations in this field, to be used for subsequent semantic retrieval.
|
| 130 |
+
Please use professional terms and concepts, and include relevant professional vocabulary even if you are uncertain about some details.
|
| 131 |
+
|
| 132 |
+
Question: {sub_query}
|
| 133 |
+
|
| 134 |
+
Professional technical content:"""
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
hypothetical_doc = self.llm.generate(
|
| 138 |
+
prompt=prompt,
|
| 139 |
+
temperature=self.temperature_hyde,
|
| 140 |
+
max_tokens=500
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
hypothetical_doc = hypothetical_doc.strip()
|
| 144 |
+
|
| 145 |
+
if not hypothetical_doc:
|
| 146 |
+
logger.warning(f"⚠️ 子問題 '{sub_query}' 的假設性文檔為空,使用子問題本身")
|
| 147 |
+
return sub_query
|
| 148 |
+
|
| 149 |
+
return hypothetical_doc
|
| 150 |
+
|
| 151 |
+
except Exception as e:
|
| 152 |
+
logger.error(f"⚠️ 生成假設性文檔時出錯: {e}")
|
| 153 |
+
return sub_query
|
| 154 |
+
|
| 155 |
+
def _generate_step_back_question(self, question: str) -> str:
|
| 156 |
+
"""生成 Step-back 抽象問題"""
|
| 157 |
+
is_chinese = PromptFormatter.detect_language(question) == "zh"
|
| 158 |
+
|
| 159 |
+
if is_chinese:
|
| 160 |
+
prompt = f"""你是一個資深專家。請將以下具體問題轉換為一個更抽象、更基礎的原理性問題。
|
| 161 |
+
這個抽象問題應該幫助理解該領域的基礎概念和原理,而不是直接回答具體問題。
|
| 162 |
+
|
| 163 |
+
具體問題: {question}
|
| 164 |
+
|
| 165 |
+
請生成一個抽象問題,用於檢索相關的原理和背景知識:
|
| 166 |
+
"""
|
| 167 |
+
else:
|
| 168 |
+
prompt = f"""You are a senior expert. Please convert the following specific question into a more abstract, fundamental question about principles and concepts.
|
| 169 |
+
This abstract question should help understand the basic concepts and principles in this field, rather than directly answering the specific question.
|
| 170 |
+
|
| 171 |
+
Specific question: {question}
|
| 172 |
+
|
| 173 |
+
Please generate an abstract question for retrieving relevant principles and background knowledge:
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
abstract_question = self.llm.generate(
|
| 178 |
+
prompt=prompt,
|
| 179 |
+
temperature=self.temperature_stepback,
|
| 180 |
+
max_tokens=200
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
abstract_question = abstract_question.strip()
|
| 184 |
+
|
| 185 |
+
if not abstract_question:
|
| 186 |
+
logger.warning("⚠️ 生成的抽象問題為空,使用原始問題")
|
| 187 |
+
return question
|
| 188 |
+
|
| 189 |
+
return abstract_question
|
| 190 |
+
|
| 191 |
+
except Exception as e:
|
| 192 |
+
logger.error(f"⚠️ 生成抽象問題時出錯: {e}")
|
| 193 |
+
return question
|
| 194 |
+
|
| 195 |
+
def _get_doc_id(self, doc: Dict) -> str:
|
| 196 |
+
"""生成文檔的唯一標識符"""
|
| 197 |
+
metadata = doc.get("metadata", {})
|
| 198 |
+
content = doc.get("content", "")
|
| 199 |
+
|
| 200 |
+
if "arxiv_id" in metadata and "chunk_index" in metadata:
|
| 201 |
+
return f"{metadata['arxiv_id']}_{metadata['chunk_index']}"
|
| 202 |
+
elif "file_path" in metadata and "chunk_index" in metadata:
|
| 203 |
+
return f"{metadata['file_path']}_{metadata['chunk_index']}"
|
| 204 |
+
else:
|
| 205 |
+
content_hash = hashlib.md5(content.encode()).hexdigest()[:16]
|
| 206 |
+
return f"doc_{content_hash}"
|
| 207 |
+
|
| 208 |
+
    def _process_subquery_with_hyde(
        self,
        sub_query: str,
        metadata_filter: Optional[Dict] = None
    ) -> tuple:
        """Process one sub-question: generate a hypothetical document (HyDE) and retrieve with it.

        Args:
            sub_query: One sub-question produced by query decomposition.
            metadata_filter: Optional metadata constraints forwarded to the
                vector retriever.

        Returns:
            A ``(results, hypothetical_doc)`` tuple. On any failure it degrades
            to ``([], "")`` so the caller's fan-out over sub-questions keeps
            running instead of aborting the whole retrieval.
        """
        try:
            # HyDE: embed a generated pseudo-answer instead of the raw
            # question, which tends to land closer to answer-bearing chunks.
            hypothetical_doc = self._generate_hypothetical_document(sub_query)
            results = self.vector_retriever.retrieve(
                query=hypothetical_doc,
                top_k=self.top_k_per_subquery,
                metadata_filter=metadata_filter
            )
            return results, hypothetical_doc
        except Exception as e:
            logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
            return [], ""
|
| 225 |
+
|
| 226 |
+
    def query(
        self,
        question: str,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        return_sub_queries: bool = False,
        return_hypothetical: bool = False,
        return_abstract_question: bool = False
    ) -> Dict:
        """
        Run the triple-hybrid RAG retrieval.

        Pipeline:
        1. Decompose the question into sub-questions (SubQuery).
        2. For each sub-question, generate a hypothetical document and retrieve
           with it (HyDE).
        3. Retrieve directly with the original question (specific facts).
        4. Generate an abstract question and retrieve with it (Step-back,
           abstract principles).
        5. Merge all results and deduplicate.

        Args:
            question: The user's original question.
            top_k: Number of final results to return (also used per track
                for the direct and step-back retrievals).
            metadata_filter: Optional metadata constraints for the retriever.
            return_sub_queries: Include the generated sub-questions in the result.
            return_hypothetical: Include the HyDE pseudo-documents in the result.
            return_abstract_question: Include the step-back question in the result.

        Returns:
            A dict with the merged ``results`` plus per-track raw results,
            optional intermediates, and timing information.
        """
        start_time = time.time()

        # Step 1: decompose the question into sub-questions.
        logger.info(f"🔍 [SubQuery] 拆解問題: '{question}'")
        sub_queries = self._generate_sub_queries(question)
        logger.info(f"✅ 生成 {len(sub_queries)} 個子問題")

        # Step 2: for each sub-question, generate a hypothetical document
        # and retrieve with it (HyDE).
        logger.info(f"📚 [HyDE] 為每個子問題生成假設性文檔並檢索...")
        subquery_results = []
        hypothetical_docs = {}

        if self.enable_parallel and len(sub_queries) > 1:
            # Fan out sub-questions across a small thread pool; each worker
            # degrades to ([], "") on failure, so one bad sub-query cannot
            # sink the whole batch.
            with ThreadPoolExecutor(max_workers=min(len(sub_queries), 5)) as executor:
                future_to_query = {
                    executor.submit(self._process_subquery_with_hyde, sq, metadata_filter): sq
                    for sq in sub_queries
                }

                for future in as_completed(future_to_query):
                    sub_query = future_to_query[future]
                    try:
                        results, hypo_doc = future.result()
                        hypothetical_docs[sub_query] = hypo_doc
                        subquery_results.extend(results)
                    except Exception as e:
                        logger.error(f"⚠️ 處理子問題 '{sub_query}' 時出錯: {e}")
        else:
            # Sequential fallback (parallelism disabled or single sub-query).
            for sub_query in sub_queries:
                results, hypo_doc = self._process_subquery_with_hyde(sub_query, metadata_filter)
                hypothetical_docs[sub_query] = hypo_doc
                subquery_results.extend(results)

        # Step 3: step-back dual-track retrieval (direct question + abstract
        # question).
        logger.info(f"🔍 [Step-back] 執行雙軌檢索...")

        if self.enable_parallel:
            with ThreadPoolExecutor(max_workers=2) as executor:
                # Direct retrieval starts immediately; the abstract question
                # is generated on this thread while the direct track runs.
                direct_future = executor.submit(
                    self.vector_retriever.retrieve,
                    question, top_k, metadata_filter
                )
                abstract_question = self._generate_step_back_question(question)
                step_back_future = executor.submit(
                    self.vector_retriever.retrieve,
                    abstract_question, top_k, metadata_filter
                )

                specific_results = direct_future.result()
                abstract_results = step_back_future.result()
        else:
            specific_results = self.vector_retriever.retrieve(
                query=question,
                top_k=top_k,
                metadata_filter=metadata_filter
            )
            abstract_question = self._generate_step_back_question(question)
            abstract_results = self.vector_retriever.retrieve(
                query=abstract_question,
                top_k=top_k,
                metadata_filter=metadata_filter
            )

        # Step 4: merge all three tracks and deduplicate by document id.
        logger.info(f"🔄 合併並去重所有檢索結果...")
        all_results = subquery_results + specific_results + abstract_results
        unique_docs = {}

        for doc in all_results:
            doc_id = self._get_doc_id(doc)
            if doc_id not in unique_docs:
                unique_docs[doc_id] = doc
            else:
                # Same chunk found via multiple tracks: keep the higher score.
                existing_score = unique_docs[doc_id].get('score', 0)
                new_score = doc.get('score', 0)
                if new_score > existing_score:
                    unique_docs[doc_id] = doc

        # Sort by score descending and keep the top_k.
        result_list = list(unique_docs.values())
        result_list.sort(key=lambda x: x.get('score', 0), reverse=True)
        final_results = result_list[:top_k]

        elapsed_time = time.time() - start_time
        logger.info(
            f"✅ 三重混合檢索完成(耗時: {elapsed_time:.2f}s)\n"
            f"   子問題檢索: {len(subquery_results)} 個結果\n"
            f"   具體事實: {len(specific_results)} 個結果\n"
            f"   抽象原理: {len(abstract_results)} 個結果\n"
            f"   去重後總計: {len(result_list)} 個,返回前 {len(final_results)} 個"
        )

        return {
            "results": final_results,
            "total_docs_found": len(result_list),
            "sub_queries": sub_queries if return_sub_queries else None,
            "hypothetical_documents": hypothetical_docs if return_hypothetical else None,
            "abstract_question": abstract_question if return_abstract_question else None,
            "subquery_results": subquery_results,
            "specific_context": specific_results,
            "abstract_context": abstract_results,
            "question": question,
            "elapsed_time": elapsed_time
        }
|
| 350 |
+
|
| 351 |
+
    def generate_answer(
        self,
        question: str,
        formatter: PromptFormatter,
        top_k: int = 5,
        metadata_filter: Optional[Dict] = None,
        document_type: str = "general",
        return_sub_queries: bool = False,
        return_hypothetical: bool = False,
        return_abstract_question: bool = False
    ) -> Dict:
        """
        Full triple-hybrid RAG flow: retrieval + answer generation.

        Args:
            question: The user's question.
            formatter: PromptFormatter used to render retrieved chunks into
                context text and to detect the question's language.
            top_k: Number of documents per context section.
            metadata_filter: Optional metadata constraints for retrieval.
            document_type: Passed through to ``formatter.format_context``.
            return_sub_queries / return_hypothetical / return_abstract_question:
                Forwarded to :meth:`query` to include intermediates in the result.

        Returns:
            The retrieval result dict augmented with ``answer``,
            ``formatted_context`` (the three context sections) and timing.
        """
        start_time = time.time()

        # Retrieve via all three tracks.
        retrieval_result = self.query(
            question=question,
            top_k=top_k,
            metadata_filter=metadata_filter,
            return_sub_queries=return_sub_queries,
            return_hypothetical=return_hypothetical,
            return_abstract_question=return_abstract_question
        )

        # Nothing retrieved: return early with a fixed apology answer.
        if not retrieval_result["results"]:
            return {
                **retrieval_result,
                "answer": "抱歉,未找到相關文檔來回答此問題。",
                "formatted_context": None,
                "answer_time": 0.0,
                "total_time": retrieval_result["elapsed_time"]
            }

        # Format the three context sections; each falls back to a fixed
        # "nothing found" string when its track produced no results.
        subquery_context = formatter.format_context(
            retrieval_result["subquery_results"][:top_k],
            document_type=document_type
        ) if retrieval_result.get("subquery_results") else "未找到相關的子問題檢索結果。"

        specific_context = formatter.format_context(
            retrieval_result["specific_context"],
            document_type=document_type
        ) if retrieval_result.get("specific_context") else "未找到相關的具體事實資料。"

        abstract_context = formatter.format_context(
            retrieval_result["abstract_context"],
            document_type=document_type
        ) if retrieval_result.get("abstract_context") else "未找到相關的基礎原理資料。"

        # Build the fusion prompt (the key step): principles + facts +
        # sub-question material, in the language of the question.
        is_chinese = PromptFormatter.detect_language(question) == "zh"

        if is_chinese:
            final_prompt = f"""你是一個資深專家。請結合以下三類資訊來回答使用者的具體問題。

【基礎原理與背景】(來自 Step-back 抽象問題檢索)
{abstract_context}

【具體事實資料】(來自直接問題檢索)
{specific_context}

【子問題相關資料】(來自 SubQuery + HyDE 檢索)
{subquery_context}

使用者問題:{question}

請根據原理推導、結合具體事實,並參考子問題的相關資料,給出一個專業、全面且具備邏輯的回答:
"""
        else:
            final_prompt = f"""You are a senior expert. Please answer the user's specific question by combining the following three types of information.

【Fundamental Principles and Background】(from Step-back abstract question retrieval)
{abstract_context}

【Specific Facts and Data】(from direct question retrieval)
{specific_context}

【Sub-question Related Information】(from SubQuery + HyDE retrieval)
{subquery_context}

User question: {question}

Please provide a professional, comprehensive, and logical answer based on principles, facts, and sub-question related information:
"""

        # Generate the answer; on failure the error text becomes the answer
        # rather than raising, so the caller always gets a result dict.
        logger.info("🤖 生成回答中...")
        answer_start = time.time()
        try:
            answer = self.llm.generate(
                prompt=final_prompt,
                temperature=self.answer_temperature,
                max_tokens=2048
            )
            answer_time = time.time() - answer_start
            logger.info(f"✅ 回答生成完成(耗時: {answer_time:.2f}s)")
        except Exception as e:
            logger.error(f"❌ 生成回答時出錯: {e}")
            answer = f"生成回答時出錯: {e}"
            answer_time = time.time() - answer_start

        total_time = time.time() - start_time

        return {
            **retrieval_result,
            "answer": answer,
            "formatted_context": {
                "subquery": subquery_context,
                "specific": specific_context,
                "abstract": abstract_context
            },
            "answer_time": answer_time,
            "total_time": total_time
        }
|
| 467 |
+
|
uv.lock
CHANGED
|
@@ -179,6 +179,19 @@ wheels = [
|
|
| 179 |
{ url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" },
|
| 180 |
]
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
[[package]]
|
| 183 |
name = "attrs"
|
| 184 |
version = "25.4.0"
|
|
@@ -808,6 +821,9 @@ version = "0.1.0"
|
|
| 808 |
source = { virtual = "." }
|
| 809 |
dependencies = [
|
| 810 |
{ name = "accelerate" },
|
|
|
|
|
|
|
|
|
|
| 811 |
{ name = "einops" },
|
| 812 |
{ name = "fastapi" },
|
| 813 |
{ name = "google-api-python-client" },
|
|
@@ -820,9 +836,12 @@ dependencies = [
|
|
| 820 |
{ name = "langchain" },
|
| 821 |
{ name = "langchain-chroma" },
|
| 822 |
{ name = "langchain-community" },
|
|
|
|
| 823 |
{ name = "langchain-google-genai" },
|
| 824 |
{ name = "langchain-groq" },
|
|
|
|
| 825 |
{ name = "langchain-tavily" },
|
|
|
|
| 826 |
{ name = "langgraph" },
|
| 827 |
{ name = "langserve", extra = ["all"] },
|
| 828 |
{ name = "mcp" },
|
|
@@ -833,6 +852,7 @@ dependencies = [
|
|
| 833 |
{ name = "pillow" },
|
| 834 |
{ name = "pypdf" },
|
| 835 |
{ name = "python-dotenv" },
|
|
|
|
| 836 |
{ name = "sentence-transformers" },
|
| 837 |
{ name = "tavily-python" },
|
| 838 |
{ name = "torch" },
|
|
@@ -844,6 +864,9 @@ dependencies = [
|
|
| 844 |
[package.metadata]
|
| 845 |
requires-dist = [
|
| 846 |
{ name = "accelerate", specifier = ">=1.12.0" },
|
|
|
|
|
|
|
|
|
|
| 847 |
{ name = "einops", specifier = ">=0.8.1" },
|
| 848 |
{ name = "fastapi", specifier = ">=0.124.2" },
|
| 849 |
{ name = "google-api-python-client", specifier = ">=2.187.0" },
|
|
@@ -856,9 +879,12 @@ requires-dist = [
|
|
| 856 |
{ name = "langchain", specifier = ">=1.1.3" },
|
| 857 |
{ name = "langchain-chroma", specifier = ">=1.0.0" },
|
| 858 |
{ name = "langchain-community", specifier = ">=0.4.1" },
|
|
|
|
| 859 |
{ name = "langchain-google-genai", specifier = ">=4.0.0" },
|
| 860 |
{ name = "langchain-groq", specifier = ">=1.1.0" },
|
|
|
|
| 861 |
{ name = "langchain-tavily", specifier = ">=0.2.13" },
|
|
|
|
| 862 |
{ name = "langgraph", specifier = ">=1.0.4" },
|
| 863 |
{ name = "langserve", extras = ["all"], specifier = ">=0.3.3" },
|
| 864 |
{ name = "mcp", specifier = ">=1.24.0" },
|
|
@@ -869,6 +895,7 @@ requires-dist = [
|
|
| 869 |
{ name = "pillow", specifier = ">=12.0.0" },
|
| 870 |
{ name = "pypdf", specifier = ">=6.4.1" },
|
| 871 |
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
|
|
|
| 872 |
{ name = "sentence-transformers", specifier = ">=5.2.0" },
|
| 873 |
{ name = "tavily-python", specifier = ">=0.7.14" },
|
| 874 |
{ name = "torch", specifier = ">=2.9.1" },
|
|
@@ -908,6 +935,7 @@ wheels = [
|
|
| 908 |
]
|
| 909 |
|
| 910 |
[[package]]
|
|
|
|
| 911 |
name = "dnspython"
|
| 912 |
version = "2.8.0"
|
| 913 |
source = { registry = "https://pypi.org/simple" }
|
|
@@ -932,6 +960,14 @@ source = { registry = "https://pypi.org/simple" }
|
|
| 932 |
sdist = { url = "https://files.pythonhosted.org/packages/ae/b6/03bb70946330e88ffec97aefd3ea75ba575cb2e762061e0e62a213befee8/docutils-0.22.4.tar.gz", hash = "sha256:4db53b1fde9abecbb74d91230d32ab626d94f6badfc575d6db9194a49df29968", size = 2291750, upload-time = "2025-12-18T19:00:26.443Z" }
|
| 933 |
wheels = [
|
| 934 |
{ url = "https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl", hash = "sha256:d0013f540772d1420576855455d050a2180186c91c15779301ac2ccb3eeb68de", size = 633196, upload-time = "2025-12-18T19:00:18.077Z" },
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 935 |
]
|
| 936 |
|
| 937 |
[[package]]
|
|
@@ -990,6 +1026,7 @@ wheels = [
|
|
| 990 |
]
|
| 991 |
|
| 992 |
[[package]]
|
|
|
|
| 993 |
name = "fastmcp"
|
| 994 |
version = "2.13.0"
|
| 995 |
source = { registry = "https://pypi.org/simple" }
|
|
@@ -1012,6 +1049,17 @@ dependencies = [
|
|
| 1012 |
sdist = { url = "https://files.pythonhosted.org/packages/bc/3b/c30af894db2c3ec439d0e4168ba7ce705474cabdd0a599033ad9a19ad977/fastmcp-2.13.0.tar.gz", hash = "sha256:57f7b7503363e1babc0d1a13af18252b80366a409e1de85f1256cce66a4bee35", size = 7767346, upload-time = "2025-10-25T12:54:10.957Z" }
|
| 1013 |
wheels = [
|
| 1014 |
{ url = "https://files.pythonhosted.org/packages/c0/7f/09942135f506953fc61bb81b9e5eaf50a8eea923b83d9135bd959168ef2d/fastmcp-2.13.0-py3-none-any.whl", hash = "sha256:bdff1399d3b7ebb79286edfd43eb660182432514a5ab8e4cbfb45f1d841d2aa0", size = 367134, upload-time = "2025-10-25T12:54:09.284Z" },
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1015 |
]
|
| 1016 |
|
| 1017 |
[[package]]
|
|
@@ -2116,6 +2164,19 @@ wheels = [
|
|
| 2116 |
{ url = "https://files.pythonhosted.org/packages/83/bd/9df897cbc98290bf71140104ee5b9777cf5291afb80333aa7da5a497339b/langchain_core-1.2.5-py3-none-any.whl", hash = "sha256:3255944ef4e21b2551facb319bfc426057a40247c0a05de5bd6f2fc021fbfa34", size = 484851, upload-time = "2025-12-22T23:45:30.525Z" },
|
| 2117 |
]
|
| 2118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2119 |
[[package]]
|
| 2120 |
name = "langchain-google-genai"
|
| 2121 |
version = "4.1.1"
|
|
@@ -2144,6 +2205,19 @@ wheels = [
|
|
| 2144 |
{ url = "https://files.pythonhosted.org/packages/af/4a/3d6227a16fe9f79968414b50e50869519378b20653805e2e8fab283908e6/langchain_groq-1.1.1-py3-none-any.whl", hash = "sha256:1c6d5146f60205dcde09d7e47bb5291c295d3f0c7bcd2417e4d3a73a04bd1050", size = 19039, upload-time = "2025-12-12T22:00:45.86Z" },
|
| 2145 |
]
|
| 2146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2147 |
[[package]]
|
| 2148 |
name = "langchain-tavily"
|
| 2149 |
version = "0.2.16"
|
|
@@ -2989,6 +3063,19 @@ wheels = [
|
|
| 2989 |
{ url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" },
|
| 2990 |
]
|
| 2991 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2992 |
[[package]]
|
| 2993 |
name = "onnxruntime"
|
| 2994 |
version = "1.23.2"
|
|
@@ -4101,6 +4188,18 @@ wheels = [
|
|
| 4101 |
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
|
| 4102 |
]
|
| 4103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4104 |
[[package]]
|
| 4105 |
name = "referencing"
|
| 4106 |
version = "0.36.2"
|
|
@@ -4585,6 +4684,12 @@ wheels = [
|
|
| 4585 |
{ url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" },
|
| 4586 |
]
|
| 4587 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4588 |
[[package]]
|
| 4589 |
name = "shellingham"
|
| 4590 |
version = "1.5.4"
|
|
|
|
| 179 |
{ url = "https://files.pythonhosted.org/packages/7f/9c/36c5c37947ebfb8c7f22e0eb6e4d188ee2d53aa3880f3f2744fb894f0cb1/anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb", size = 113362, upload-time = "2025-11-28T23:36:57.897Z" },
|
| 180 |
]
|
| 181 |
|
| 182 |
+
[[package]]
|
| 183 |
+
name = "arxiv"
|
| 184 |
+
version = "2.3.1"
|
| 185 |
+
source = { registry = "https://pypi.org/simple" }
|
| 186 |
+
dependencies = [
|
| 187 |
+
{ name = "feedparser" },
|
| 188 |
+
{ name = "requests" },
|
| 189 |
+
]
|
| 190 |
+
sdist = { url = "https://files.pythonhosted.org/packages/dd/95/65e38ddfb54762a8f1777bbe80da2cebf7941376e67a2212de487d9372db/arxiv-2.3.1.tar.gz", hash = "sha256:08567185dfc102c8d349de4b9e84dfde0af46d6402486e3009afc90f8ccf9709", size = 16692, upload-time = "2025-11-13T06:22:59.853Z" }
|
| 191 |
+
wheels = [
|
| 192 |
+
{ url = "https://files.pythonhosted.org/packages/90/7f/340847023184305a6378d75ec71e1dd38a942dfe71b7c29314b8fbe26948/arxiv-2.3.1-py3-none-any.whl", hash = "sha256:eb5a0b76808cc0a16de0c1448df0f927a3cf576096686d8e335a98b8872df1be", size = 11565, upload-time = "2025-11-13T06:22:58.662Z" },
|
| 193 |
+
]
|
| 194 |
+
|
| 195 |
[[package]]
|
| 196 |
name = "attrs"
|
| 197 |
version = "25.4.0"
|
|
|
|
| 821 |
source = { virtual = "." }
|
| 822 |
dependencies = [
|
| 823 |
{ name = "accelerate" },
|
| 824 |
+
{ name = "arxiv" },
|
| 825 |
+
{ name = "chromadb" },
|
| 826 |
+
{ name = "docx2txt" },
|
| 827 |
{ name = "einops" },
|
| 828 |
{ name = "fastapi" },
|
| 829 |
{ name = "google-api-python-client" },
|
|
|
|
| 836 |
{ name = "langchain" },
|
| 837 |
{ name = "langchain-chroma" },
|
| 838 |
{ name = "langchain-community" },
|
| 839 |
+
{ name = "langchain-experimental" },
|
| 840 |
{ name = "langchain-google-genai" },
|
| 841 |
{ name = "langchain-groq" },
|
| 842 |
+
{ name = "langchain-ollama" },
|
| 843 |
{ name = "langchain-tavily" },
|
| 844 |
+
{ name = "langchain-text-splitters" },
|
| 845 |
{ name = "langgraph" },
|
| 846 |
{ name = "langserve", extra = ["all"] },
|
| 847 |
{ name = "mcp" },
|
|
|
|
| 852 |
{ name = "pillow" },
|
| 853 |
{ name = "pypdf" },
|
| 854 |
{ name = "python-dotenv" },
|
| 855 |
+
{ name = "rank-bm25" },
|
| 856 |
{ name = "sentence-transformers" },
|
| 857 |
{ name = "tavily-python" },
|
| 858 |
{ name = "torch" },
|
|
|
|
| 864 |
[package.metadata]
|
| 865 |
requires-dist = [
|
| 866 |
{ name = "accelerate", specifier = ">=1.12.0" },
|
| 867 |
+
{ name = "arxiv", specifier = ">=2.3.1" },
|
| 868 |
+
{ name = "chromadb", specifier = ">=0.4.22" },
|
| 869 |
+
{ name = "docx2txt", specifier = ">=0.8" },
|
| 870 |
{ name = "einops", specifier = ">=0.8.1" },
|
| 871 |
{ name = "fastapi", specifier = ">=0.124.2" },
|
| 872 |
{ name = "google-api-python-client", specifier = ">=2.187.0" },
|
|
|
|
| 879 |
{ name = "langchain", specifier = ">=1.1.3" },
|
| 880 |
{ name = "langchain-chroma", specifier = ">=1.0.0" },
|
| 881 |
{ name = "langchain-community", specifier = ">=0.4.1" },
|
| 882 |
+
{ name = "langchain-experimental", specifier = ">=0.0.50" },
|
| 883 |
{ name = "langchain-google-genai", specifier = ">=4.0.0" },
|
| 884 |
{ name = "langchain-groq", specifier = ">=1.1.0" },
|
| 885 |
+
{ name = "langchain-ollama", specifier = ">=0.1.0" },
|
| 886 |
{ name = "langchain-tavily", specifier = ">=0.2.13" },
|
| 887 |
+
{ name = "langchain-text-splitters", specifier = ">=0.0.1" },
|
| 888 |
{ name = "langgraph", specifier = ">=1.0.4" },
|
| 889 |
{ name = "langserve", extras = ["all"], specifier = ">=0.3.3" },
|
| 890 |
{ name = "mcp", specifier = ">=1.24.0" },
|
|
|
|
| 895 |
{ name = "pillow", specifier = ">=12.0.0" },
|
| 896 |
{ name = "pypdf", specifier = ">=6.4.1" },
|
| 897 |
{ name = "python-dotenv", specifier = ">=1.2.1" },
|
| 898 |
+
{ name = "rank-bm25", specifier = ">=0.2.2" },
|
| 899 |
{ name = "sentence-transformers", specifier = ">=5.2.0" },
|
| 900 |
{ name = "tavily-python", specifier = ">=0.7.14" },
|
| 901 |
{ name = "torch", specifier = ">=2.9.1" },
|
|
|
|
| 935 |
]
|
| 936 |
|
| 937 |
[[package]]
|
| 938 |
+
<<<<<<< HEAD
|
| 939 |
name = "dnspython"
|
| 940 |
version = "2.8.0"
|
| 941 |
source = { registry = "https://pypi.org/simple" }
|
|
|
|
| 960 |
sdist = { url = "https://files.pythonhosted.org/packages/ae/b6/03bb70946330e88ffec97aefd3ea75ba575cb2e762061e0e62a213befee8/docutils-0.22.4.tar.gz", hash = "sha256:4db53b1fde9abecbb74d91230d32ab626d94f6badfc575d6db9194a49df29968", size = 2291750, upload-time = "2025-12-18T19:00:26.443Z" }
|
| 961 |
wheels = [
|
| 962 |
{ url = "https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl", hash = "sha256:d0013f540772d1420576855455d050a2180186c91c15779301ac2ccb3eeb68de", size = 633196, upload-time = "2025-12-18T19:00:18.077Z" },
|
| 963 |
+
=======
|
| 964 |
+
name = "docx2txt"
|
| 965 |
+
version = "0.9"
|
| 966 |
+
source = { registry = "https://pypi.org/simple" }
|
| 967 |
+
sdist = { url = "https://files.pythonhosted.org/packages/ea/07/4486a038624e885e227fe79111914c01f55aa70a51920ff1a7f2bd216d10/docx2txt-0.9.tar.gz", hash = "sha256:18013f6229b14909028b19aa7bf4f8f3d6e4632d7b089ab29f7f0a4d1f660e28", size = 3613, upload-time = "2025-03-24T20:59:25.21Z" }
|
| 968 |
+
wheels = [
|
| 969 |
+
{ url = "https://files.pythonhosted.org/packages/d6/51/756e71bec48ece0ecc2a10e921ef2756e197dcb7e478f2b43673b6683902/docx2txt-0.9-py3-none-any.whl", hash = "sha256:e3718c0653fd6f2fcf4b51b02a61452ad1c38a4c163bcf0a6fd9486cd38f529a", size = 4025, upload-time = "2025-03-24T20:59:24.394Z" },
|
| 970 |
+
>>>>>>> 5beccbe9dfa0ef53e4123976ad54e2f1c28b72f8
|
| 971 |
]
|
| 972 |
|
| 973 |
[[package]]
|
|
|
|
| 1026 |
]
|
| 1027 |
|
| 1028 |
[[package]]
|
name = "fastmcp"
version = "2.13.0"
source = { registry = "https://pypi.org/simple" }

sdist = { url = "https://files.pythonhosted.org/packages/bc/3b/c30af894db2c3ec439d0e4168ba7ce705474cabdd0a599033ad9a19ad977/fastmcp-2.13.0.tar.gz", hash = "sha256:57f7b7503363e1babc0d1a13af18252b80366a409e1de85f1256cce66a4bee35", size = 7767346, upload-time = "2025-10-25T12:54:10.957Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/c0/7f/09942135f506953fc61bb81b9e5eaf50a8eea923b83d9135bd959168ef2d/fastmcp-2.13.0-py3-none-any.whl", hash = "sha256:bdff1399d3b7ebb79286edfd43eb660182432514a5ab8e4cbfb45f1d841d2aa0", size = 367134, upload-time = "2025-10-25T12:54:09.284Z" },
]

[[package]]
name = "feedparser"
version = "6.0.12"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "sgmllib3k" },
]
sdist = { url = "https://files.pythonhosted.org/packages/dc/79/db7edb5e77d6dfbc54d7d9df72828be4318275b2e580549ff45a962f6461/feedparser-6.0.12.tar.gz", hash = "sha256:64f76ce90ae3e8ef5d1ede0f8d3b50ce26bcce71dd8ae5e82b1cd2d4a5f94228", size = 286579, upload-time = "2025-09-10T13:33:59.486Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/4e/eb/c96d64137e29ae17d83ad2552470bafe3a7a915e85434d9942077d7fd011/feedparser-6.0.12-py3-none-any.whl", hash = "sha256:6bbff10f5a52662c00a2e3f86a38928c37c48f77b3c511aedcd51de933549324", size = 81480, upload-time = "2025-09-10T13:33:58.022Z" },
|
| 1063 |
]
|
| 1064 |
|
| 1065 |
[[package]]
|
|
|
|
| 2164 |
{ url = "https://files.pythonhosted.org/packages/83/bd/9df897cbc98290bf71140104ee5b9777cf5291afb80333aa7da5a497339b/langchain_core-1.2.5-py3-none-any.whl", hash = "sha256:3255944ef4e21b2551facb319bfc426057a40247c0a05de5bd6f2fc021fbfa34", size = 484851, upload-time = "2025-12-22T23:45:30.525Z" },
|
| 2165 |
]
|
| 2166 |
|
| 2167 |
+
[[package]]
|
| 2168 |
+
name = "langchain-experimental"
|
| 2169 |
+
version = "0.4.1"
|
| 2170 |
+
source = { registry = "https://pypi.org/simple" }
|
| 2171 |
+
dependencies = [
|
| 2172 |
+
{ name = "langchain-community" },
|
| 2173 |
+
{ name = "langchain-core" },
|
| 2174 |
+
]
|
| 2175 |
+
sdist = { url = "https://files.pythonhosted.org/packages/a2/ec/6fe7b2e3c105b4f4fc6b943d8fc1b5b10f883429edc36c58a09fc2e28419/langchain_experimental-0.4.1.tar.gz", hash = "sha256:ab6b19a0b98fbc15225fbfcf096176fec339b7e3e930bcf328bb717985fc1da5", size = 170449, upload-time = "2025-12-11T05:30:48.455Z" }
|
| 2176 |
+
wheels = [
|
| 2177 |
+
{ url = "https://files.pythonhosted.org/packages/24/fa/fb2c8b6418e1c9ef50c82b3b6e0184bce321582577240bb4b8ed3274a4aa/langchain_experimental-0.4.1-py3-none-any.whl", hash = "sha256:b6ee2f42b50aaadb45e581439ecf5ee50f3a6a0986d52e74d1e64721309e387d", size = 210096, upload-time = "2025-12-11T05:30:47.234Z" },
|
| 2178 |
+
]
|
| 2179 |
+
|
| 2180 |
[[package]]
|
| 2181 |
name = "langchain-google-genai"
|
| 2182 |
version = "4.1.1"
|
|
|
|
| 2205 |
{ url = "https://files.pythonhosted.org/packages/af/4a/3d6227a16fe9f79968414b50e50869519378b20653805e2e8fab283908e6/langchain_groq-1.1.1-py3-none-any.whl", hash = "sha256:1c6d5146f60205dcde09d7e47bb5291c295d3f0c7bcd2417e4d3a73a04bd1050", size = 19039, upload-time = "2025-12-12T22:00:45.86Z" },
|
| 2206 |
]
|
| 2207 |
|
| 2208 |
+
[[package]]
|
| 2209 |
+
name = "langchain-ollama"
|
| 2210 |
+
version = "1.0.1"
|
| 2211 |
+
source = { registry = "https://pypi.org/simple" }
|
| 2212 |
+
dependencies = [
|
| 2213 |
+
{ name = "langchain-core" },
|
| 2214 |
+
{ name = "ollama" },
|
| 2215 |
+
]
|
| 2216 |
+
sdist = { url = "https://files.pythonhosted.org/packages/73/51/72cd04d74278f3575f921084f34280e2f837211dc008c9671c268c578afe/langchain_ollama-1.0.1.tar.gz", hash = "sha256:e37880c2f41cdb0895e863b1cfd0c2c840a117868b3f32e44fef42569e367443", size = 153850, upload-time = "2025-12-12T21:48:28.68Z" }
|
| 2217 |
+
wheels = [
|
| 2218 |
+
{ url = "https://files.pythonhosted.org/packages/e3/46/f2907da16dc5a5a6c679f83b7de21176178afad8d2ca635a581429580ef6/langchain_ollama-1.0.1-py3-none-any.whl", hash = "sha256:37eb939a4718a0255fe31e19fbb0def044746c717b01b97d397606ebc3e9b440", size = 29207, upload-time = "2025-12-12T21:48:27.832Z" },
|
| 2219 |
+
]
|
| 2220 |
+
|
| 2221 |
[[package]]
|
| 2222 |
name = "langchain-tavily"
|
| 2223 |
version = "0.2.16"
|
|
|
|
| 3063 |
{ url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" },
|
| 3064 |
]
|
| 3065 |
|
| 3066 |
+
[[package]]
|
| 3067 |
+
name = "ollama"
|
| 3068 |
+
version = "0.6.1"
|
| 3069 |
+
source = { registry = "https://pypi.org/simple" }
|
| 3070 |
+
dependencies = [
|
| 3071 |
+
{ name = "httpx" },
|
| 3072 |
+
{ name = "pydantic" },
|
| 3073 |
+
]
|
| 3074 |
+
sdist = { url = "https://files.pythonhosted.org/packages/9d/5a/652dac4b7affc2b37b95386f8ae78f22808af09d720689e3d7a86b6ed98e/ollama-0.6.1.tar.gz", hash = "sha256:478c67546836430034b415ed64fa890fd3d1ff91781a9d548b3325274e69d7c6", size = 51620, upload-time = "2025-11-13T23:02:17.416Z" }
|
| 3075 |
+
wheels = [
|
| 3076 |
+
{ url = "https://files.pythonhosted.org/packages/47/4f/4a617ee93d8208d2bcf26b2d8b9402ceaed03e3853c754940e2290fed063/ollama-0.6.1-py3-none-any.whl", hash = "sha256:fc4c984b345735c5486faeee67d8a265214a31cbb828167782dc642ce0a2bf8c", size = 14354, upload-time = "2025-11-13T23:02:16.292Z" },
|
| 3077 |
+
]
|
| 3078 |
+
|
| 3079 |
[[package]]
|
| 3080 |
name = "onnxruntime"
|
| 3081 |
version = "1.23.2"
|
|
|
|
| 4188 |
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
|
| 4189 |
]
|
| 4190 |
|
| 4191 |
+
[[package]]
|
| 4192 |
+
name = "rank-bm25"
|
| 4193 |
+
version = "0.2.2"
|
| 4194 |
+
source = { registry = "https://pypi.org/simple" }
|
| 4195 |
+
dependencies = [
|
| 4196 |
+
{ name = "numpy" },
|
| 4197 |
+
]
|
| 4198 |
+
sdist = { url = "https://files.pythonhosted.org/packages/fc/0a/f9579384aa017d8b4c15613f86954b92a95a93d641cc849182467cf0bb3b/rank_bm25-0.2.2.tar.gz", hash = "sha256:096ccef76f8188563419aaf384a02f0ea459503fdf77901378d4fd9d87e5e51d", size = 8347, upload-time = "2022-02-16T12:10:52.196Z" }
|
| 4199 |
+
wheels = [
|
| 4200 |
+
{ url = "https://files.pythonhosted.org/packages/2a/21/f691fb2613100a62b3fa91e9988c991e9ca5b89ea31c0d3152a3210344f9/rank_bm25-0.2.2-py3-none-any.whl", hash = "sha256:7bd4a95571adadfc271746fa146a4bcfd89c0cf731e49c3d1ad863290adbe8ae", size = 8584, upload-time = "2022-02-16T12:10:50.626Z" },
|
| 4201 |
+
]
|
| 4202 |
+
|
| 4203 |
[[package]]
|
| 4204 |
name = "referencing"
|
| 4205 |
version = "0.36.2"
|
|
|
|
| 4684 |
{ url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" },
|
| 4685 |
]
|
| 4686 |
|
| 4687 |
+
[[package]]
|
| 4688 |
+
name = "sgmllib3k"
|
| 4689 |
+
version = "1.0.0"
|
| 4690 |
+
source = { registry = "https://pypi.org/simple" }
|
| 4691 |
+
sdist = { url = "https://files.pythonhosted.org/packages/9e/bd/3704a8c3e0942d711c1299ebf7b9091930adae6675d7c8f476a7ce48653c/sgmllib3k-1.0.0.tar.gz", hash = "sha256:7868fb1c8bfa764c1ac563d3cf369c381d1325d36124933a726f29fcdaa812e9", size = 5750, upload-time = "2010-08-24T14:33:52.445Z" }
|
| 4692 |
+
|
| 4693 |
[[package]]
|
| 4694 |
name = "shellingham"
|
| 4695 |
version = "1.5.4"
|