|
|
import gradio as gr |
|
|
import pickle |
|
|
import os |
|
|
import cv2 |
|
|
import pandas as pd |
|
|
import re |
|
|
from symspellpy import SymSpell, Verbosity |
|
|
from rapidocr import RapidOCR, EngineType, LangDet, LangRec, ModelType, OCRVersion |
|
|
import logging |
|
|
|
|
|
|
|
|
|
|
|
# Root logging config: INFO level so startup, cache, and OCR progress messages
# are visible on the console.
logging.basicConfig(level=logging.INFO)

# Module-level logger per the stdlib getLogger(__name__) convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
# Anchors candidate drug names to dosage-form prefixes in OCR output.
#   group(1): the prefix itself — tab/cap/t/inj/syp, each with optional dot
#   group(2): the candidate drug name (letters/digits/spaces, matched lazily)
# The optional non-capturing parts skip a trailing strength number and an
# mg/ml unit, and the lookahead stops the name at the next prefix, a digit,
# or end of string. Compiled once at module load since it runs per OCR line.
ANCHOR_PATTERN = re.compile(
    r"\b(tab\.?|cap\.?|t\.?|inj\.?|syp\.?)\s*([a-zA-Z0-9\s]+?)(?:\s+\d+)?(?:\s+(?:mg|ml))?(?=\s*(?:tab|cap|inj|syp|t\.|\d+|$))",
    re.IGNORECASE
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lazily-initialized module singletons. All start as None and are populated
# once by get_ocr_engine() / initialize_database().
_ocr_engine = None   # RapidOCR engine instance (built on first use)
_drug_db = None      # dict mapping lowercase drug name -> True (membership set)
_sym_spell = None    # SymSpell dictionary built over _drug_db keys
# Pickle cache location, resolved relative to this file so the app works
# regardless of the current working directory.
_cache_path = os.path.join(os.path.dirname(__file__), "cache", "database_cache.pkl")
|
|
|
|
|
|
|
|
def ensure_cache_dir():
    """Create the cache directory for the pickled database if missing.

    ``os.makedirs(..., exist_ok=True)`` is already idempotent, so the
    previous ``os.path.exists`` pre-check was a redundant extra stat and a
    check-then-act race; a single call is both simpler and safe.
    """
    os.makedirs(os.path.dirname(_cache_path), exist_ok=True)
|
|
|
|
|
|
|
|
def initialize_database():
    """
    Load the drug database and SymSpell dictionary once into module globals.

    A pickle cache under ``cache/`` skips the CSV parse and dictionary build
    on subsequent startups. Any cache failure falls through to a full
    rebuild, and any CSV failure falls back to a minimal built-in database,
    so this function never raises.
    """
    global _drug_db, _sym_spell

    ensure_cache_dir()

    # Fast path: reuse a previously pickled database + SymSpell instance.
    # NOTE(review): pickle.load is only acceptable because this cache file is
    # produced locally by this app; never point _cache_path at untrusted data.
    if os.path.exists(_cache_path):
        logger.info("Loading database from cache...")
        try:
            with open(_cache_path, 'rb') as f:
                cache_data = pickle.load(f)
            _drug_db = cache_data['drug_db']
            _sym_spell = cache_data['sym_spell']
            logger.info(f"Cache loaded: {len(_drug_db)} drugs")
            return
        except Exception as e:
            logger.warning(f"Cache load failed: {e}. Recomputing...")

    logger.info("Initializing database...")
    _drug_db = {}

    try:
        df = pd.read_csv("Dataset.csv")
        if 'drug_name' in df.columns:
            # Vectorized normalization of the whole column instead of the
            # previous row-by-row iterrows() loop (which also left an unused
            # index variable). NaN cells become the string "nan", exactly as
            # str(row.get(...)) did before.
            for name in df['drug_name'].astype(str).str.strip().str.lower():
                if name:
                    _drug_db[name] = True
    except Exception as e:
        logger.warning(f"Dataset loading failed: {e}. Using minimal DB.")
        _drug_db = {"aspirin": True, "paracetamol": True}

    # Edit distance 1 keeps SymSpell lookups fast while still catching
    # single-character OCR mistakes.
    _sym_spell = SymSpell(max_dictionary_edit_distance=1)
    for drug in _drug_db:
        _sym_spell.create_dictionary_entry(drug, 1000)

    # Best-effort cache write; failure is logged but never fatal.
    try:
        cache_data = {
            'drug_db': _drug_db,
            'sym_spell': _sym_spell
        }
        with open(_cache_path, 'wb') as f:
            pickle.dump(cache_data, f)
        logger.info(f"Database cached: {len(_drug_db)} drugs")
    except Exception as e:
        logger.warning(f"Cache save failed: {e}")
|
|
|
|
|
|
|
|
def get_ocr_engine():
    """Return the shared RapidOCR engine, building it lazily on first use.

    The engine is configured for speed: MOBILE detection and recognition
    models on the ONNX runtime, with input capped at a 1280px longest side.
    """
    global _ocr_engine
    if _ocr_engine is not None:
        return _ocr_engine

    logger.info("Initializing RapidOCR engine with MOBILE models...")
    engine_params = {"Global.max_side_len": 1280}
    # Det and Rec share the same engine/model/version settings; only the
    # language enum type differs between the two stages.
    for stage, lang_type in (("Det", LangDet.CH), ("Rec", LangRec.CH)):
        engine_params[f"{stage}.engine_type"] = EngineType.ONNXRUNTIME
        engine_params[f"{stage}.lang_type"] = lang_type
        engine_params[f"{stage}.model_type"] = ModelType.MOBILE
        engine_params[f"{stage}.ocr_version"] = OCRVersion.PPOCRV5
    _ocr_engine = RapidOCR(params=engine_params)
    return _ocr_engine
|
|
|
|
|
|
|
|
def validate_drug_match(term: str) -> str | None:
    """Map an OCR'd term to a canonical database drug name.

    Args:
        term: Candidate drug name extracted from an OCR line. May carry
            surrounding whitespace captured by the anchor regex.

    Returns:
        The canonical lowercase database entry, or ``None`` when the term is
        noise. (The previous ``-> str`` annotation was wrong: the miss path
        has always returned ``None``.)
    """
    # Normalize: the regex group can capture trailing spaces, which would
    # otherwise defeat the exact-match lookup below.
    term = term.strip().lower()

    # Exact hit: already canonical.
    if term in _drug_db:
        return term

    # Length gate: very short terms are ambiguous, very long ones are almost
    # certainly merged OCR noise — skip the fuzzy lookup for both.
    if len(term) < 3 or len(term) > 15:
        return None

    # Fuzzy match within edit distance 1 (single-character OCR errors).
    suggestions = _sym_spell.lookup(term, Verbosity.CLOSEST, max_edit_distance=1)
    if suggestions and suggestions[0].term in _drug_db:
        return suggestions[0].term

    return None
|
|
|
|
|
|
|
|
def extract_drugs_from_line(line_text: str) -> list:
    """
    Extract canonical drug names that follow tab/cap/t/inj/syp prefixes.

    Args:
        line_text: One line of OCR output.

    Returns:
        List of canonical drug names validated against the database. May
        contain duplicates within a line; the caller de-duplicates via a set.
    """
    drugs = []
    for match in ANCHOR_PATTERN.finditer(line_text):
        # group(2) is the candidate drug name after the dosage-form prefix.
        # (group(1), the prefix itself, was previously bound to an unused
        # local and is not needed here.)
        canonical = validate_drug_match(match.group(2))
        if canonical:
            drugs.append(canonical)

    return drugs
|
|
|
|
|
|
|
|
def process_image_ocr(image):
    """Run fast OCR on a BGR image and extract medication names.

    Args:
        image: BGR numpy image (cv2 convention), any size.

    Returns:
        On success, a dict with keys ``drugs`` (sorted canonical names),
        ``raw_lines``, ``drugs_count``, and ``elapse``. On OCR failure,
        ``{"error": <message>}``.
    """
    logger.info("Processing image...")

    ocr_engine = get_ocr_engine()

    # Pre-shrink very wide images before handing them to the engine; the
    # engine's max_side_len is 1280 anyway, so this only avoids shipping
    # extra pixels into it.
    height, width = image.shape[:2]
    if width > 1280:
        scale = 1280 / width
        image = cv2.resize(image, None, fx=scale, fy=scale)

    try:
        ocr_result = ocr_engine(
            image,
            use_det=True,
            use_cls=False,  # orientation classification skipped for speed
            use_rec=True,
        )
    except Exception as e:
        logger.error(f"OCR failed: {e}")
        return {"error": str(e)}

    # No text recognized: return a result carrying ALL the keys callers read.
    # BUG FIX: the previous early return omitted "drugs_count"/"elapse",
    # making process_input raise KeyError on blank images. getattr(..., None)
    # also covers the case where .txts exists but is None, which would have
    # crashed the loop below.
    if not ocr_result or not getattr(ocr_result, 'txts', None):
        return {"drugs": [], "raw_lines": [], "drugs_count": 0, "elapse": "N/A"}

    drugs_found = set()
    raw_lines = []

    for line_text in ocr_result.txts:
        if not line_text:
            continue

        raw_lines.append(line_text)
        drugs_found.update(extract_drugs_from_line(line_text))

    return {
        "drugs": sorted(drugs_found),
        "raw_lines": raw_lines,
        "drugs_count": len(drugs_found),
        "elapse": f"{ocr_result.elapse:.3f}s"
    }
|
|
|
|
|
|
|
|
def process_input(image_input):
    """Gradio interface handler: OCR an uploaded image and format results.

    Args:
        image_input: RGB numpy array from the gr.Image component, or None
            when nothing was uploaded.

    Returns:
        Tuple of (summary string, medications dict for the JSON component).
        Errors are reported in the summary string; this never raises.
    """
    if image_input is None:
        return "Please upload an image.", {}

    try:
        # Gradio delivers RGB; the OpenCV-based OCR pipeline expects BGR.
        image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)
        result = process_image_ocr(image)

        if "error" in result:
            return f"Error: {result['error']}", {}

        summary = f"Found {result['drugs_count']} medication(s) in {result.get('elapse', 'N/A')}"

        medications_json = {
            "total_medications": result["drugs_count"],
            "processing_time": result.get("elapse", "N/A"),
            "medications": [
                {
                    "id": idx + 1,
                    "name": drug.title()
                }
                for idx, drug in enumerate(result["drugs"])
            ]
        }

        return summary, medications_json
    except Exception as e:
        # logger.exception preserves the traceback that the previous plain
        # logger.error(f"...") call threw away; lazy %-args avoid formatting
        # when the level is disabled.
        logger.exception("Processing error: %s", e)
        return f"Error: {str(e)}", {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Eager startup: build/load the drug database at import time so the first
# Gradio request does not pay the initialization cost.
logger.info("Starting Medibot...")
initialize_database()
logger.info("Database initialized. Ready for inference.")
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Medibot - Fast OCR") as demo:
    gr.Markdown("# Medibot: Prescription OCR")
    gr.Markdown("Upload a prescription image to extract medications.")
    gr.Markdown("**Optimized with MOBILE models + ONNX Runtime for maximum speed**")

    # Input image and the short human-readable summary, side by side.
    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Prescription")
        output_text = gr.Textbox(label="Summary", lines=2)

    # Structured extraction results (id + title-cased name per medication).
    with gr.Row():
        medications_json = gr.JSON(label="Extracted Medications")

    # Single action: run process_input and fan its tuple out to both outputs.
    submit_btn = gr.Button("Extract Medications", variant="primary")
    submit_btn.click(process_input, inputs=image_input, outputs=[output_text, medications_json])
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bound the request backlog and worker threads so the mobile OCR models
    # are not oversubscribed under concurrent uploads.
    demo.queue(max_size=10)
    demo.launch(max_threads=4, show_error=True)