Spaces:
Sleeping
Sleeping
Update rag_dspy.py
Browse files- rag_dspy.py +18 -5
rag_dspy.py
CHANGED
|
@@ -3,6 +3,8 @@
|
|
| 3 |
import dspy
|
| 4 |
from dspy_qdrant import QdrantRM
|
| 5 |
from qdrant_client import QdrantClient, models
|
|
|
|
|
|
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
import os
|
| 8 |
|
|
@@ -10,7 +12,7 @@ load_dotenv()
|
|
| 10 |
# DSPy setup
|
| 11 |
lm = dspy.LM("gpt-4", max_tokens=512,api_key=os.environ.get("OPENAI_API_KEY"))
|
| 12 |
client = QdrantClient(url=os.environ.get("QDRANT_CLOUD_URL"), api_key=os.environ.get("QDRANT_API_KEY"))
|
| 13 |
-
collection_name = "
|
| 14 |
rm = QdrantRM(
|
| 15 |
qdrant_collection_name=collection_name,
|
| 16 |
qdrant_client=client,
|
|
@@ -22,7 +24,7 @@ dspy.settings.configure(lm=lm, rm=rm)
|
|
| 22 |
|
| 23 |
# Manual reranker using ColBERT multivector field
|
| 24 |
# Manual reranker using Qdrant’s native prefetch + ColBERT query
|
| 25 |
-
def rerank_with_colbert(query_text):
|
| 26 |
from fastembed import TextEmbedding, LateInteractionTextEmbedding
|
| 27 |
|
| 28 |
# Encode query once with both models
|
|
@@ -42,7 +44,14 @@ def rerank_with_colbert(query_text):
|
|
| 42 |
query=colbert_query,
|
| 43 |
using="colbert",
|
| 44 |
limit=5,
|
| 45 |
-
with_payload=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
)
|
| 47 |
|
| 48 |
points = results.points
|
|
@@ -56,6 +65,8 @@ def rerank_with_colbert(query_text):
|
|
| 56 |
# DSPy Signature and Module
|
| 57 |
class MedicalAnswer(dspy.Signature):
|
| 58 |
question = dspy.InputField(desc="The medical question to answer")
|
|
|
|
|
|
|
| 59 |
context = dspy.OutputField(desc="The answer to the medical question")
|
| 60 |
final_answer = dspy.OutputField(desc="The answer to the medical question")
|
| 61 |
|
|
@@ -63,12 +74,14 @@ class MedicalRAG(dspy.Module):
|
|
| 63 |
def __init__(self):
|
| 64 |
super().__init__()
|
| 65 |
|
| 66 |
-
def forward(self, question):
|
| 67 |
-
reranked_docs = rerank_with_colbert(question)
|
| 68 |
|
| 69 |
context_str = "\n".join(reranked_docs)
|
| 70 |
|
| 71 |
return dspy.ChainOfThought(MedicalAnswer)(
|
| 72 |
question=question,
|
|
|
|
|
|
|
| 73 |
context=context_str
|
| 74 |
)
|
|
|
|
| 3 |
import dspy
|
| 4 |
from dspy_qdrant import QdrantRM
|
| 5 |
from qdrant_client import QdrantClient, models
|
| 6 |
+
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
| 7 |
+
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
import os
|
| 10 |
|
|
|
|
| 12 |
# DSPy setup
|
| 13 |
lm = dspy.LM("gpt-4", max_tokens=512,api_key=os.environ.get("OPENAI_API_KEY"))
|
| 14 |
client = QdrantClient(url=os.environ.get("QDRANT_CLOUD_URL"), api_key=os.environ.get("QDRANT_API_KEY"))
|
| 15 |
+
collection_name = "indexed_medical_chat_bot"
|
| 16 |
rm = QdrantRM(
|
| 17 |
qdrant_collection_name=collection_name,
|
| 18 |
qdrant_client=client,
|
|
|
|
| 24 |
|
| 25 |
# Manual reranker using ColBERT multivector field
|
| 26 |
# Manual reranker using Qdrant’s native prefetch + ColBERT query
|
| 27 |
+
def rerank_with_colbert(query_text, year, specialty):
|
| 28 |
from fastembed import TextEmbedding, LateInteractionTextEmbedding
|
| 29 |
|
| 30 |
# Encode query once with both models
|
|
|
|
| 44 |
query=colbert_query,
|
| 45 |
using="colbert",
|
| 46 |
limit=5,
|
| 47 |
+
with_payload=True,
|
| 48 |
+
query_filter=Filter(
|
| 49 |
+
must=[
|
| 50 |
+
FieldCondition(key="specialty", match=MatchValue(value=specialty)),
|
| 51 |
+
FieldCondition(key="year", match=MatchValue(value=year))
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
)
|
| 55 |
)
|
| 56 |
|
| 57 |
points = results.points
|
|
|
|
| 65 |
# DSPy Signature and Module
|
| 66 |
class MedicalAnswer(dspy.Signature):
|
| 67 |
question = dspy.InputField(desc="The medical question to answer")
|
| 68 |
+
year = dspy.InputField(desc="The year of the medical paper")
|
| 69 |
+
specialty = dspy.InputField(desc="The specialty of the medical paper")
|
| 70 |
context = dspy.OutputField(desc="The answer to the medical question")
|
| 71 |
final_answer = dspy.OutputField(desc="The answer to the medical question")
|
| 72 |
|
|
|
|
| 74 |
def __init__(self):
|
| 75 |
super().__init__()
|
| 76 |
|
| 77 |
+
def forward(self, question, year, specialty):
|
| 78 |
+
reranked_docs = rerank_with_colbert(question, year, specialty)
|
| 79 |
|
| 80 |
context_str = "\n".join(reranked_docs)
|
| 81 |
|
| 82 |
return dspy.ChainOfThought(MedicalAnswer)(
|
| 83 |
question=question,
|
| 84 |
+
year=year,
|
| 85 |
+
specialty=specialty,
|
| 86 |
context=context_str
|
| 87 |
)
|