Abhinav Gavireddi committed
Commit 69979b2 · 1 Parent(s): 7c8a700

[fix]: fixed Streamlit application slow startup

Files changed:
- Dockerfile        +7 -7
- requirements.txt  +1 -6
- src/__init__.py   +45 -28
- src/config.py     +0 -38
- src/gpp.py        +25 -32
- src/qa.py         +5 -15
- src/retriever.py  +6 -13
- src/utils.py      +2 -7
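
Nearly all of the change is one pattern: imports of heavy ML libraries move from module scope into the methods that first need them, so importing the app's modules no longer pays that cost at Streamlit startup. A minimal sketch of the pattern (hypothetical `model_name` parameter, not the full class from the diff):

class GPP:
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        # Lazy: sentence_transformers (and, transitively, torch) now loads when
        # the first GPP is constructed, not when the module is imported.
        from sentence_transformers import SentenceTransformer
        self.text_embedder = SentenceTransformer(model_name)

Because Python caches imported modules in sys.modules, repeating such an import across several methods costs essentially nothing after the first call.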
Dockerfile CHANGED

@@ -15,21 +15,21 @@ RUN apt-get update && \
     # for hnswlib (needed for OpenMP)
     libgomp1 \
     curl \
-    git \
-
+    git && \
+    rm -rf /var/lib/apt/lists/*
 
 # Copy and install Python dependencies
 COPY requirements.txt ./
-RUN pip install
+RUN pip install uv && \
+    uv pip install --no-cache-dir -r requirements.txt
 
-#
-RUN
-    pip install git+https://github.com/opendatalab/MinerU.git@dev && \
-    curl -L https://github.com/opendatalab/MinerU/raw/dev/scripts/download_models_hf.py -o download_models_hf.py && \
+# Download models (if needed at build time)
+RUN curl -L https://github.com/opendatalab/MinerU/raw/dev/scripts/download_models_hf.py -o download_models_hf.py && \
     python download_models_hf.py
 
 # Copy application code
 COPY src/ ./src/
+
 # COPY tests/ ./tests/
 COPY app.py .
 
requirements.txt CHANGED

@@ -26,7 +26,6 @@ ultralytics>=8.3.48
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
 dill>=0.3.9,<1
-rapid_table>=1.0.3,<2.0.0
 PyYAML>=6.0.2,<7
 ftfy>=6.3.1,<7
 openai>=1.70.0,<2
@@ -37,9 +36,5 @@ shapely>=2.0.7,<3
 pyclipper>=1.3.0,<2
 omegaconf>=2.3.0,<3
 tqdm>=4.67.1
-
 # MinerU
-git+https://github.com/opendatalab/MinerU.git@dev
-
-# Testing
-pytest>=7.0
+git+https://github.com/opendatalab/MinerU.git@dev
src/__init__.py CHANGED

@@ -2,37 +2,54 @@ import os
 from dotenv import load_dotenv
 import bleach
 
-import logging
-import sys
-import structlog
-
 load_dotenv()
 
-os.system('python src/ghm.py')
-
-def configure_logging():
-    structlog.configure(
-        processors=[
-            structlog.processors.TimeStamper(fmt="iso"),
-            structlog.processors.JSONRenderer()
-        ],
-        context_class=dict,
-        logger_factory=structlog.stdlib.LoggerFactory(),
-        wrapper_class=structlog.stdlib.BoundLogger,
-        cache_logger_on_first_use=True,
-    )
-    if not logging.getLogger().handlers:
-        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
-
-def get_env(name):
-    val = os.getenv(name)
-    if not val:
-        raise RuntimeError(f"Missing required secret: {name}")
-    return val
-
 def sanitize_html(raw):
     # allow only text and basic tags
     return bleach.clean(raw, tags=[], strip=True)
 
-
-
+"""
+Central configuration for the entire Document Intelligence app.
+All modules import from here rather than hard-coding values.
+"""
+
+OPENAI_EMBEDDING_MODEL = os.getenv(
+    "OPENAI_EMBEDDING_MODEL", "text-embedding-ada-002"
+)
+class EmbeddingConfig:
+    PROVIDER = os.getenv("EMBEDDING_PROVIDER", 'HF')
+    TEXT_MODEL = os.getenv('TEXT_EMBED_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')
+    META_MODEL = os.getenv('META_EMBED_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')
+
+class RetrieverConfig:
+    PROVIDER = os.getenv("EMBEDDING_PROVIDER", 'HF')
+    TOP_K = int(os.getenv('RETRIEVER_TOP_K', 10))
+    DENSE_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
+    ANN_TOP = int(os.getenv('ANN_TOP', 50))
+
+class RerankerConfig:
+    @staticmethod
+    def get_device():
+        import torch
+        return 'cuda' if torch.cuda.is_available() else 'cpu'
+    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-Gemma')
+    DEVICE = get_device()
+
+class GPPConfig:
+    CHUNK_TOKEN_SIZE = int(os.getenv('CHUNK_TOKEN_SIZE', 256))
+    DEDUP_SIM_THRESHOLD = float(os.getenv('DEDUP_SIM_THRESHOLD', 0.9))
+    EXPANSION_SIM_THRESHOLD = float(os.getenv('EXPANSION_SIM_THRESHOLD', 0.85))
+    COREF_CONTEXT_SIZE = int(os.getenv('COREF_CONTEXT_SIZE', 3))
+
+class GPPConfig:
+    """
+    Configuration for GPP pipeline.
+    """
+
+    CHUNK_TOKEN_SIZE = 256
+    DEDUP_SIM_THRESHOLD = 0.9
+    EXPANSION_SIM_THRESHOLD = 0.85
+    COREF_CONTEXT_SIZE = 3
+    HNSW_EF_CONSTRUCTION = int(os.getenv("HNSW_EF_CONSTRUCTION", "200"))
+    HNSW_M = int(os.getenv("HNSW_M", "16"))
+    HNSW_EF_SEARCH = int(os.getenv("HNSW_EF_SEARCH", "50"))
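
One caveat in the new RerankerConfig: `DEVICE = get_device()` runs while the class body is evaluated, so torch is still imported the moment the src package loads; and before Python 3.10 a staticmethod object is not callable inside the class body at all. A fully lazy variant (a sketch, assuming a cached module-level helper is acceptable here, not the committed code):

import functools
import os

@functools.lru_cache(maxsize=1)
def get_device() -> str:
    import torch  # deferred until the first call, then cached
    return 'cuda' if torch.cuda.is_available() else 'cpu'

class RerankerConfig:
    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-Gemma')
    # Resolve at use sites via RerankerConfig.resolve_device() rather than
    # reading a DEVICE attribute computed at import time.
    resolve_device = staticmethod(get_device)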
src/config.py DELETED

@@ -1,38 +0,0 @@
-"""
-Central configuration for the entire Document Intelligence app.
-All modules import from here rather than hard-coding values.
-"""
-import os
-
-# class RedisConfig:
-#     HOST = os.getenv('REDIS_HOST', 'localhost')
-#     PORT = int(os.getenv('REDIS_PORT', 6379))
-#     DB = int(os.getenv('REDIS_DB', 0))
-#     VECTOR_INDEX = os.getenv('REDIS_VECTOR_INDEX', 'gpp_vectors')
-
-OPENAI_EMBEDDING_MODEL = os.getenv(
-    "OPENAI_EMBEDDING_MODEL", "text-embedding-ada-002"
-)
-class EmbeddingConfig:
-    PROVIDER = os.getenv("EMBEDDING_PROVIDER", 'HF')
-    TEXT_MODEL = os.getenv('TEXT_EMBED_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')
-    META_MODEL = os.getenv('META_EMBED_MODEL', 'sentence-transformers/all-MiniLM-L6-v2')
-    # TEXT_MODEL = OPENAI_EMBEDDING_MODEL
-    # META_MODEL = OPENAI_EMBEDDING_MODEL
-
-class RetrieverConfig:
-    PROVIDER = os.getenv("EMBEDDING_PROVIDER", 'HF')
-    TOP_K = int(os.getenv('RETRIEVER_TOP_K', 10))  # number of candidates per retrieval path
-    DENSE_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
-    # DENSE_MODEL = OPENAI_EMBEDDING_MODEL
-    ANN_TOP = int(os.getenv('ANN_TOP', 50))
-
-class RerankerConfig:
-    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-Gemma')
-    DEVICE = os.getenv('RERANKER_DEVICE', 'cuda' if os.getenv('CUDA_VISIBLE_DEVICES') else 'cpu')
-
-class GPPConfig:
-    CHUNK_TOKEN_SIZE = int(os.getenv('CHUNK_TOKEN_SIZE', 256))
-    DEDUP_SIM_THRESHOLD = float(os.getenv('DEDUP_SIM_THRESHOLD', 0.9))
-    EXPANSION_SIM_THRESHOLD = float(os.getenv('EXPANSION_SIM_THRESHOLD', 0.85))
-    COREF_CONTEXT_SIZE = int(os.getenv('COREF_CONTEXT_SIZE', 3))
src/gpp.py CHANGED

@@ -18,23 +18,8 @@ import json
 from typing import List, Dict, Any, Optional
 import re
 
-from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
-from magic_pdf.data.dataset import PymuDocDataset
-from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.config.enums import SupportedPdfParseMethod
-
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from sentence_transformers import SentenceTransformer
-from rank_bm25 import BM25Okapi
-import numpy as np
-import hnswlib
-
-from src.config import EmbeddingConfig
-from src.utils import OpenAIEmbedder
-
-# LLM client abstraction
-from src.utils import LLMClient, logger
-
+from src import EmbeddingConfig, GPPConfig
+from src.utils import OpenAIEmbedder, LLMClient, logger
 
 def parse_markdown_table(md: str) -> Optional[Dict[str, Any]]:
     """
@@ -60,23 +45,11 @@ def parse_markdown_table(md: str) -> Optional[Dict[str, Any]]:
     return {"headers": headers, "rows": rows}
 
 
-class GPPConfig:
-    """
-    Configuration for GPP pipeline.
-    """
-
-    CHUNK_TOKEN_SIZE = 256
-    DEDUP_SIM_THRESHOLD = 0.9
-    EXPANSION_SIM_THRESHOLD = 0.85
-    COREF_CONTEXT_SIZE = 3
-    HNSW_EF_CONSTRUCTION = int(os.getenv("HNSW_EF_CONSTRUCTION", "200"))
-    HNSW_M = int(os.getenv("HNSW_M", "16"))
-    HNSW_EF_SEARCH = int(os.getenv("HNSW_EF_SEARCH", "50"))
-
-
 class GPP:
     def __init__(self, config: GPPConfig):
         self.config = config
+        # Lazy import heavy libraries
+        from sentence_transformers import SentenceTransformer
         # Embedding models
         if EmbeddingConfig.PROVIDER == "openai":
             self.text_embedder = OpenAIEmbedder(EmbeddingConfig.TEXT_MODEL)
@@ -97,6 +70,12 @@ class GPP:
         dumps markdown, images, layout PDF, content_list JSON.
         Returns parsed data plus file paths for UI traceability.
         """
+        # Lazy import heavy libraries
+        from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
+        from magic_pdf.data.dataset import PymuDocDataset
+        from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+        from magic_pdf.config.enums import SupportedPdfParseMethod
+
         name = os.path.splitext(os.path.basename(pdf_path))[0]
         img_dir = os.path.join(output_dir, "images")
         os.makedirs(img_dir, exist_ok=True)
@@ -138,6 +117,9 @@ class GPP:
         Creates chunks of ~CHUNK_TOKEN_SIZE tokens, but ensures any table/image block
         becomes its own chunk (unsplittable), flushing current text chunk as needed.
         """
+        # Lazy import heavy libraries
+        from langchain.text_splitter import RecursiveCharacterTextSplitter
+
         chunks, current, token_count = [], {"text": "", "type": None, "blocks": []}, 0
         for blk in blocks:
             btype = blk.get("type")
@@ -185,7 +167,10 @@ class GPP:
 
     def deduplicate(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         try:
-            #
+            # Lazy import heavy libraries
+            import numpy as np
+            from sentence_transformers import SentenceTransformer
+
             narrations = [c.get("narration", "") for c in chunks]
             if EmbeddingConfig.PROVIDER == "openai":
                 embs = self.text_embedder.embed(narrations)
@@ -236,6 +221,9 @@ class GPP:
         """
         Build BM25 index on token lists for sparse retrieval.
         """
+        # Lazy import heavy libraries
+        from rank_bm25 import BM25Okapi
+
         tokenized = [c["narration"].split() for c in chunks]
         self.bm25 = BM25Okapi(tokenized)
 
@@ -248,6 +236,11 @@ class GPP:
         4. Dump human-readable chunk metadata (incl. section_summary)
            for traceability in the UI.
         """
+        # Lazy import heavy libraries
+        import numpy as np
+        import hnswlib
+        from sentence_transformers import SentenceTransformer
+
         # --- 1. Prepare embedder ---
         if EmbeddingConfig.PROVIDER.lower() == "openai":
             embedder = OpenAIEmbedder(EmbeddingConfig.TEXT_MODEL)
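
A quick way to check what the lazy imports buy: time the libraries gpp.py used to import at module scope. A sketch (run in a fresh interpreter; anything already in sys.modules re-imports for free):

import importlib
import time

# Library names taken from the removed module-level imports above.
for name in ("sentence_transformers", "hnswlib", "rank_bm25", "magic_pdf"):
    start = time.perf_counter()
    try:
        importlib.import_module(name)
    except ImportError:
        print(f"{name}: not installed")
        continue
    print(f"{name}: imported in {time.perf_counter() - start:.2f}s")

CPython's built-in breakdown, `python -X importtime -c "import src.gpp"`, gives the same information per module.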
src/qa.py CHANGED

@@ -11,26 +11,18 @@ Each component is modular and can be swapped or extended (e.g., add HyDE retriev
 import os
 from typing import List, Dict, Any, Tuple
 
-from
-from rank_bm25 import BM25Okapi
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
-import torch
-
-from src import sanitize_html
+from src import RerankerConfig
 from src.utils import LLMClient, logger
 from src.retriever import Retriever, RetrieverConfig
 
-
-class RerankerConfig:
-    MODEL_NAME = os.getenv('RERANKER_MODEL', 'BAAI/bge-reranker-v2-Gemma')
-    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-
 class Reranker:
     """
     Cross-encoder re-ranker using a transformer-based sequence classification model.
     """
     def __init__(self, config: RerankerConfig):
         try:
+            from transformers import AutoTokenizer, AutoModelForSequenceClassification
+            import torch
             self.tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
             self.model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME)
             self.model.to(config.DEVICE)
@@ -44,6 +36,7 @@ class Reranker:
             logger.warning('No candidates provided to rerank.')
             return []
         try:
+            import torch
             inputs = self.tokenizer(
                 [query] * len(candidates),
                 [c.get('narration', '') for c in candidates],
@@ -59,10 +52,7 @@ class Reranker:
             logits = logits.squeeze(-1)  # only squeeze if it's (batch, 1)
 
             probs = torch.sigmoid(logits).cpu().numpy().flatten()  # flatten always ensures 1D array
-            paired = []
-            for idx, c in enumerate(candidates):
-                score = float(probs[idx])
-                paired.append((c, score))
+            paired = [(c, float(probs[idx])) for idx, c in enumerate(candidates)]
 
             ranked = sorted(paired, key=lambda x: x[1], reverse=True)
             return [c for c, _ in ranked[:top_k]]
src/retriever.py CHANGED

@@ -1,27 +1,19 @@
 import os
-import numpy as np
-import hnswlib
 from typing import List, Dict, Any
 
-from sentence_transformers import SentenceTransformer
-from rank_bm25 import BM25Okapi
-
 from src.config import RetrieverConfig
 from src.utils import logger
 
-
 class Retriever:
     """
     Hybrid retriever combining BM25 sparse and dense retrieval (no Redis).
     """
     def __init__(self, chunks: List[Dict[str, Any]], config: RetrieverConfig):
-
-
-
-
-
-        config (RetrieverConfig): Configuration for the retriever.
-        """
+        # Lazy import heavy libraries
+        import numpy as np
+        import hnswlib
+        from sentence_transformers import SentenceTransformer
+        from rank_bm25 import BM25Okapi
         self.chunks = chunks
         try:
             if not isinstance(chunks, list) or not all(isinstance(c, dict) for c in chunks):
@@ -58,6 +50,7 @@ class Retriever:
             return []
         tokenized = query.split()
         try:
+            import numpy as np  # Ensure np is defined here
             scores = self.bm25.get_scores(tokenized)
             top_indices = np.argsort(scores)[::-1][:top_k]
             return [self.chunks[i] for i in top_indices]
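
The repeated `import numpy as np` inside the retrieval path looks wasteful but is not: an import statement consults sys.modules first, so after the first execution it reduces to a dict lookup plus a local name binding. A small demonstration, assuming numpy is installed:

import sys
import time

def load_numpy() -> float:
    start = time.perf_counter()
    import numpy  # noqa: F401 -- only the first call pays the real cost
    return time.perf_counter() - start

cold, warm = load_numpy(), load_numpy()
print(f"cold: {cold:.4f}s  cached: {warm:.6f}s")
print("numpy" in sys.modules)  # True: cached for the life of the process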
src/utils.py CHANGED

@@ -7,11 +7,6 @@ from typing import List
 from openai import AzureOpenAI
 from langchain_openai import AzureOpenAIEmbeddings
 
-try:
-    from src.utils import logger
-except ImportError:
-    import structlog
-    logger = structlog.get_logger()
 
 class LLMClient:
     """
@@ -26,7 +21,7 @@ class LLMClient:
         openai_model_name = model or os.getenv('OPENAI_MODEL', 'gpt-4o')
 
         if not (azure_api_key or azure_endpoint or azure_api_version or openai_model_name):
-
+            print('OPENAI_API_KEY is not set')
             raise EnvironmentError('Missing OPENAI_API_KEY')
         client = AzureOpenAI(
             api_key=azure_api_key,
@@ -45,7 +40,7 @@ class LLMClient:
             text = resp.choices[0].message.content.strip()
             return text
         except Exception as e:
-
+            print('LLM generation failed')
             raise
 
 