IFMedTechdemo's picture
updated the regex of anchor prefix
450410c verified
raw
history blame
7.95 kB
import gradio as gr
import pickle
import os
import cv2
import pandas as pd
import re
from symspellpy import SymSpell, Verbosity
from rapidocr import RapidOCR, EngineType, LangDet, LangRec, ModelType, OCRVersion
import logging
# Configure logging once at import time; INFO level so startup progress
# (cache hits, DB size, engine init) is visible in the Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants - Match both separated and merged prefixes.
# Group 1: dosage-form prefix (tab/cap/t/inj/syp, with optional dot).
# Group 2: candidate drug name (letters/digits/spaces, matched lazily).
# The optional non-capturing groups swallow a trailing dose number and
# "mg"/"ml" unit; the lookahead stops the name at the next prefix,
# a number, or end of line, so consecutive entries on one OCR line
# are split correctly.
ANCHOR_PATTERN = re.compile(
    r"\b(tab\.?|cap\.?|t\.?|inj\.?|syp\.?)\s*([a-zA-Z0-9\s]+?)(?:\s+\d+)?(?:\s+(?:mg|ml))?(?=\s*(?:tab|cap|inj|syp|t\.|\d+|$))",
    re.IGNORECASE
)
# ============================================================================
# GLOBAL SINGLETONS
# ============================================================================
_ocr_engine = None  # lazily created RapidOCR engine (see get_ocr_engine)
_drug_db = None     # dict mapping lowercase drug name -> True (set semantics)
_sym_spell = None   # SymSpell instance for fuzzy drug-name matching
_cache_path = os.path.join(os.path.dirname(__file__), "cache", "database_cache.pkl")
def ensure_cache_dir(path=None):
    """Ensure the directory that will hold *path* exists.

    Args:
        path: File path whose parent directory should be created.
            Defaults to the module-level ``_cache_path``.
    """
    cache_dir = os.path.dirname(path if path is not None else _cache_path)
    # exist_ok=True already tolerates an existing directory, so the
    # prior os.path.exists() check was redundant (and race-prone).
    os.makedirs(cache_dir, exist_ok=True)
def initialize_database():
    """
    Load the drug database and SymSpell spell-checker exactly once.

    Populates the module-level ``_drug_db`` and ``_sym_spell`` singletons.
    A pickle cache skips the expensive recomputation on later startups;
    a failed cache read falls through to recomputing from "Dataset.csv".
    """
    global _drug_db, _sym_spell
    ensure_cache_dir()

    # Fast path: reuse a previously pickled database if one exists.
    if os.path.exists(_cache_path):
        logger.info("Loading database from cache...")
        try:
            with open(_cache_path, 'rb') as f:
                cache_data = pickle.load(f)
            _drug_db = cache_data['drug_db']
            _sym_spell = cache_data['sym_spell']
            logger.info(f"Cache loaded: {len(_drug_db)} drugs")
            return
        except Exception as e:
            logger.warning(f"Cache load failed: {e}. Recomputing...")

    # Slow path: build the lookup table from the CSV dataset.
    logger.info("Initializing database...")
    _drug_db = {}
    try:
        df = pd.read_csv("Dataset.csv")
        for _, row in df.iterrows():
            raw_name = row.get('drug_name', '')
            # Skip missing cells explicitly: str(NaN) is the truthy
            # string "nan", which would otherwise be stored as a drug.
            if pd.isna(raw_name):
                continue
            drug_name = str(raw_name).strip().lower()
            if drug_name:
                _drug_db[drug_name] = True
    except Exception as e:
        logger.warning(f"Dataset loading failed: {e}. Using minimal DB.")
        _drug_db = {"aspirin": True, "paracetamol": True}

    # Build the fuzzy matcher over the canonical drug names.
    _sym_spell = SymSpell(max_dictionary_edit_distance=1)
    for drug in _drug_db:
        _sym_spell.create_dictionary_entry(drug, 1000)

    # Persist both structures so the next startup takes the fast path.
    try:
        cache_data = {
            'drug_db': _drug_db,
            'sym_spell': _sym_spell,
        }
        with open(_cache_path, 'wb') as f:
            pickle.dump(cache_data, f)
        logger.info(f"Database cached: {len(_drug_db)} drugs")
    except Exception as e:
        logger.warning(f"Cache save failed: {e}")
def get_ocr_engine():
    """Return the shared RapidOCR engine, creating it on first use.

    The engine is configured for speed: MOBILE detection/recognition
    models running on the ONNX runtime, with PP-OCRv5 weights.
    """
    global _ocr_engine
    if _ocr_engine is not None:
        return _ocr_engine

    logger.info("Initializing RapidOCR engine with MOBILE models...")
    engine_params = {
        "Global.max_side_len": 1280,
        "Det.engine_type": EngineType.ONNXRUNTIME,
        "Det.lang_type": LangDet.CH,
        "Det.model_type": ModelType.MOBILE,
        "Det.ocr_version": OCRVersion.PPOCRV5,
        "Rec.engine_type": EngineType.ONNXRUNTIME,
        "Rec.lang_type": LangRec.CH,
        "Rec.model_type": ModelType.MOBILE,
        "Rec.ocr_version": OCRVersion.PPOCRV5,
    }
    _ocr_engine = RapidOCR(params=engine_params)
    return _ocr_engine
def validate_drug_match(term: str) -> "str | None":
    """Map term to canonical database drug, or None if noise.

    Exact lookup in ``_drug_db`` first; otherwise a SymSpell fuzzy
    lookup with edit distance 1 against the same dictionary.
    """
    term = term.lower()
    if term in _drug_db:
        return term
    # Skip SymSpell for very short or long words -- tiny tokens are OCR
    # noise and very long ones are unlikely to be single drug names.
    if len(term) < 3 or len(term) > 15:
        return None
    # Fuzzy match via SymSpell (CLOSEST returns best-distance candidates).
    suggestions = _sym_spell.lookup(term, Verbosity.CLOSEST, max_edit_distance=1)
    if suggestions and suggestions[0].term in _drug_db:
        return suggestions[0].term
    return None
def extract_drugs_from_line(line_text: str):
    """
    Extract canonical drug names from one OCR line.

    Finds words anchored to (or merged with) dosage-form prefixes such
    as ``tab``/``cap``/``inj`` via ``ANCHOR_PATTERN`` and validates each
    candidate against the drug database (exact, then fuzzy).

    Args:
        line_text: A single line of recognized OCR text.

    Returns:
        list[str]: Canonical drug names found in the line (may be empty).
    """
    drugs = []
    for match in ANCHOR_PATTERN.finditer(line_text):
        # group(2) is the candidate name after the prefix; group(1)
        # (the prefix itself) was captured but never used, so it is
        # not extracted here.
        canonical = validate_drug_match(match.group(2))
        if canonical:
            drugs.append(canonical)
    return drugs
def process_image_ocr(image):
    """Fast OCR + drug extraction using MOBILE models and ONNX runtime.

    Args:
        image: BGR image as a numpy array (OpenCV convention); assumed
            (H, W, C) -- only ``shape[:2]`` is read here.

    Returns:
        dict: ``{"error": str}`` on OCR failure; ``{"drugs": [],
        "raw_lines": []}`` when no text is recognized (NOTE(review):
        this branch omits ``drugs_count``/``elapse`` -- callers should
        use ``.get``); otherwise a dict with ``drugs`` (sorted canonical
        names), ``raw_lines``, ``drugs_count`` and formatted ``elapse``.
    """
    logger.info("Processing image...")
    ocr_engine = get_ocr_engine()
    # Preprocess: resize if too large -- width capped at 1280 to match
    # the engine's Global.max_side_len setting.
    height, width = image.shape[:2]
    if width > 1280:
        scale = 1280 / width
        image = cv2.resize(image, None, fx=scale, fy=scale)
    # Run OCR with optimized settings (use_cls=False skips the text
    # orientation classifier, presumably for speed -- confirm).
    try:
        ocr_result = ocr_engine(
            image,
            use_det=True,
            use_cls=False,
            use_rec=True,
        )
    except Exception as e:
        logger.error(f"OCR failed: {e}")
        return {"error": str(e)}
    # Handle new RapidOCR return format: results expose `.txts` and
    # `.elapse` attributes instead of the old list-of-tuples.
    if not ocr_result or not hasattr(ocr_result, 'txts'):
        return {"drugs": [], "raw_lines": []}
    # Extract drugs; a set de-duplicates names repeated across lines.
    drugs_found = set()
    raw_lines = []
    for line_text in ocr_result.txts:
        if not line_text:
            continue
        raw_lines.append(line_text)
        line_drugs = extract_drugs_from_line(line_text)
        drugs_found.update(line_drugs)
    return {
        "drugs": sorted(list(drugs_found)),
        "raw_lines": raw_lines,
        "drugs_count": len(drugs_found),
        "elapse": f"{ocr_result.elapse:.3f}s"
    }
def process_input(image_input):
    """Gradio interface handler.

    Args:
        image_input: RGB numpy array from the Gradio Image component,
            or None when nothing was uploaded.

    Returns:
        tuple[str, dict]: A human-readable summary string and a
        JSON-able dict of the extracted medications ({} on error).
    """
    if image_input is None:
        return "Please upload an image.", {}
    try:
        # Convert Gradio RGB to BGR for OpenCV
        image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
        result = process_image_ocr(image)
        if "error" in result:
            return f"Error: {result['error']}", {}
        # Use .get with defaults: the "no text recognized" branch of
        # process_image_ocr omits drugs_count/elapse, and indexing them
        # directly used to raise KeyError (surfacing as a bogus error
        # for a perfectly valid empty result).
        count = result.get("drugs_count", 0)
        elapse = result.get("elapse", "N/A")
        summary = f"Found {count} medication(s) in {elapse}"
        # JSON output with all medications
        medications_json = {
            "total_medications": count,
            "processing_time": elapse,
            "medications": [
                {
                    "id": idx + 1,
                    "name": drug.title()
                }
                for idx, drug in enumerate(result.get("drugs", []))
            ]
        }
        return summary, medications_json
    except Exception as e:
        # logger.exception keeps the full traceback (logger.error alone
        # drops it); the broad catch is acceptable at this top-level UI
        # boundary where we must return something renderable.
        logger.exception(f"Processing error: {e}")
        return f"Error: {str(e)}", {}
# ============================================================================
# Gradio Interface
# ============================================================================
logger.info("Starting Medibot...")
# Build the drug DB / SymSpell singletons at import time so the first
# user request does not pay the initialization cost.
initialize_database()
logger.info("Database initialized. Ready for inference.")
# Assemble the Gradio UI: one image input, a text summary and a JSON
# panel as outputs, wired to process_input via the submit button.
with gr.Blocks(title="Medibot - Fast OCR") as demo:
    gr.Markdown("# Medibot: Prescription OCR")
    gr.Markdown("Upload a prescription image to extract medications.")
    gr.Markdown("**Optimized with MOBILE models + ONNX Runtime for maximum speed**")
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Prescription")
        output_text = gr.Textbox(label="Summary", lines=2)
    with gr.Row():
        medications_json = gr.JSON(label="Extracted Medications")
    submit_btn = gr.Button("Extract Medications", variant="primary")
    # Outputs map positionally to process_input's (summary, json) tuple.
    submit_btn.click(process_input, inputs=image_input, outputs=[output_text, medications_json])
if __name__ == "__main__":
    # queue(max_size=10) bounds concurrent pending requests;
    # NOTE(review): max_threads=4 presumably matches constrained hosting
    # hardware -- confirm. show_error surfaces tracebacks in the UI.
    demo.queue(max_size=10)
    demo.launch(max_threads=4, show_error=True)