"""Medibot: prescription-image OCR with drug-name extraction.

Pipeline: RapidOCR (mobile ONNX models) -> regex anchor matching on each
recognized text line -> exact/fuzzy (SymSpell, edit distance 1) validation
against a drug database loaded from ``Dataset.csv`` and cached via pickle.
"""

import logging
import os
import pickle
import re
from typing import List, Optional

import cv2
import gradio as gr
import pandas as pd
from rapidocr import RapidOCR, EngineType, LangDet, LangRec, ModelType, OCRVersion
from symspellpy import SymSpell, Verbosity

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants - Match both separated ("tab aspirin") and merged ("tab.aspirin")
# prefixes. Group 1 = prefix, group 2 = candidate drug name; trailing dose
# ("500", "mg"/"ml") is consumed but not captured. The lookahead stops the
# lazy name group at the next prefix/number/end of line.
ANCHOR_PATTERN = re.compile(
    r"\b(tab\.?|cap\.?|t\.?|inj\.?|syp\.?)\s*([a-zA-Z0-9\s]+?)(?:\s+\d+)?(?:\s+(?:mg|ml))?(?=\s*(?:tab|cap|inj|syp|t\.|\d+|$))",
    re.IGNORECASE
)

# ============================================================================
# GLOBAL SINGLETONS
# ============================================================================
_ocr_engine = None   # lazily-built RapidOCR instance (see get_ocr_engine)
_drug_db = None      # dict[str, bool]: lowercase drug name -> True
_sym_spell = None    # SymSpell index built over _drug_db keys
_cache_path = os.path.join(os.path.dirname(__file__), "cache", "database_cache.pkl")


def ensure_cache_dir() -> None:
    """Ensure cache directory exists."""
    # exist_ok makes the pre-check redundant and race-free.
    os.makedirs(os.path.dirname(_cache_path), exist_ok=True)


def initialize_database() -> None:
    """Load drug database and SymSpell once.

    Uses the pickle cache if available to skip expensive recomputation;
    otherwise reads ``Dataset.csv`` (falling back to a minimal built-in DB)
    and writes a fresh cache for the next startup.
    """
    global _drug_db, _sym_spell
    ensure_cache_dir()

    # Try to load from cache.
    # NOTE(review): pickle.load on the cache file is safe only as long as the
    # cache directory is trusted/local — never point _cache_path at
    # user-supplied data.
    if os.path.exists(_cache_path):
        logger.info("Loading database from cache...")
        try:
            with open(_cache_path, 'rb') as f:
                cache_data = pickle.load(f)
            _drug_db = cache_data['drug_db']
            _sym_spell = cache_data['sym_spell']
            logger.info(f"Cache loaded: {len(_drug_db)} drugs")
            return
        except Exception as e:
            logger.warning(f"Cache load failed: {e}. Recomputing...")

    # Compute from scratch
    logger.info("Initializing database...")
    _drug_db = {}
    try:
        df = pd.read_csv("Dataset.csv")
        for idx, row in df.iterrows():
            drug_name = str(row.get('drug_name', '')).strip().lower()
            if drug_name:
                _drug_db[drug_name] = True
    except Exception as e:
        # Best-effort fallback so the app still starts without the dataset.
        logger.warning(f"Dataset loading failed: {e}. Using minimal DB.")
        _drug_db = {"aspirin": True, "paracetamol": True}

    # Initialize SymSpell with drug DB (edit distance 1 keeps fuzzy matches
    # conservative; frequency 1000 is a uniform placeholder weight).
    _sym_spell = SymSpell(max_dictionary_edit_distance=1)
    for drug in _drug_db:
        _sym_spell.create_dictionary_entry(drug, 1000)

    # Cache for next startup (failure here is non-fatal).
    try:
        cache_data = {
            'drug_db': _drug_db,
            'sym_spell': _sym_spell
        }
        with open(_cache_path, 'wb') as f:
            pickle.dump(cache_data, f)
        logger.info(f"Database cached: {len(_drug_db)} drugs")
    except Exception as e:
        logger.warning(f"Cache save failed: {e}")


def get_ocr_engine() -> RapidOCR:
    """Get or create the RapidOCR engine with MOBILE + ONNX optimization."""
    global _ocr_engine
    if _ocr_engine is None:
        logger.info("Initializing RapidOCR engine with MOBILE models...")
        _ocr_engine = RapidOCR(
            params={
                "Global.max_side_len": 1280,
                "Det.engine_type": EngineType.ONNXRUNTIME,
                "Det.lang_type": LangDet.CH,
                "Det.model_type": ModelType.MOBILE,
                "Det.ocr_version": OCRVersion.PPOCRV5,
                "Rec.engine_type": EngineType.ONNXRUNTIME,
                "Rec.lang_type": LangRec.CH,
                "Rec.model_type": ModelType.MOBILE,
                "Rec.ocr_version": OCRVersion.PPOCRV5,
            }
        )
    return _ocr_engine


def validate_drug_match(term: str) -> Optional[str]:
    """Map ``term`` to a canonical database drug name, or None if noise."""
    term = term.lower().strip()
    if term in _drug_db:
        return term
    # Skip SymSpell for very short or long words: short terms produce too many
    # false fuzzy hits, long ones are almost certainly not single drug names.
    if len(term) < 3 or len(term) > 15:
        return None
    # Fuzzy match via SymSpell (edit distance <= 1).
    suggestions = _sym_spell.lookup(term, Verbosity.CLOSEST, max_edit_distance=1)
    if suggestions and suggestions[0].term in _drug_db:
        return suggestions[0].term
    return None


def extract_drugs_from_line(line_text: str) -> List[str]:
    """Extract drug names that follow or are merged with tab/cap/t prefixes."""
    drugs = []
    for match in ANCHOR_PATTERN.finditer(line_text):
        canonical = validate_drug_match(match.group(2))
        if canonical:
            drugs.append(canonical)
    return drugs


def process_image_ocr(image) -> dict:
    """Fast OCR + drug extraction using MOBILE models and ONNX runtime.

    ``image`` is a BGR numpy array. Returns a dict with either an "error"
    key, or "drugs" / "raw_lines" / "drugs_count" / "elapse".
    """
    logger.info("Processing image...")
    ocr_engine = get_ocr_engine()

    # Preprocess: downscale wide images so OCR stays within max_side_len.
    height, width = image.shape[:2]
    if width > 1280:
        scale = 1280 / width
        image = cv2.resize(image, None, fx=scale, fy=scale)

    # Run OCR with optimized settings (classification disabled for speed).
    try:
        ocr_result = ocr_engine(
            image,
            use_det=True,
            use_cls=False,
            use_rec=True,
        )
    except Exception as e:
        logger.error(f"OCR failed: {e}")
        return {"error": str(e)}

    # Handle new RapidOCR return format (object exposing .txts / .elapse).
    if not ocr_result or not hasattr(ocr_result, 'txts'):
        return {"drugs": [], "raw_lines": []}

    # Extract drugs, de-duplicated across lines.
    drugs_found = set()
    raw_lines = []
    for line_text in ocr_result.txts:
        if not line_text:
            continue
        raw_lines.append(line_text)
        drugs_found.update(extract_drugs_from_line(line_text))

    return {
        "drugs": sorted(drugs_found),
        "raw_lines": raw_lines,
        "drugs_count": len(drugs_found),
        "elapse": f"{ocr_result.elapse:.3f}s"
    }


def process_input(image_input):
    """Gradio interface handler: numpy RGB image -> (summary text, JSON)."""
    if image_input is None:
        return "Please upload an image.", {}

    try:
        # Convert Gradio RGB to BGR for OpenCV
        image = cv2.cvtColor(image_input, cv2.COLOR_RGB2BGR)

        result = process_image_ocr(image)
        if "error" in result:
            return f"Error: {result['error']}", {}

        # Summary text
        summary = f"Found {result['drugs_count']} medication(s) in {result.get('elapse', 'N/A')}"

        # JSON output with all medications
        medications_json = {
            "total_medications": result["drugs_count"],
            "processing_time": result.get("elapse", "N/A"),
            "medications": [
                {
                    "id": idx + 1,
                    "name": drug.title()
                }
                for idx, drug in enumerate(result["drugs"])
            ]
        }
        return summary, medications_json
    except Exception as e:
        logger.error(f"Processing error: {e}")
        return f"Error: {str(e)}", {}


# ============================================================================
# Gradio Interface
# ============================================================================
logger.info("Starting Medibot...")
initialize_database()
logger.info("Database initialized. Ready for inference.")

with gr.Blocks(title="Medibot - Fast OCR") as demo:
    gr.Markdown("# Medibot: Prescription OCR")
    gr.Markdown("Upload a prescription image to extract medications.")
    gr.Markdown("**Optimized with MOBILE models + ONNX Runtime for maximum speed**")

    with gr.Row():
        image_input = gr.Image(type="numpy", label="Upload Prescription")
        output_text = gr.Textbox(label="Summary", lines=2)

    with gr.Row():
        medications_json = gr.JSON(label="Extracted Medications")

    submit_btn = gr.Button("Extract Medications", variant="primary")
    submit_btn.click(process_input, inputs=image_input, outputs=[output_text, medications_json])

if __name__ == "__main__":
    demo.queue(max_size=10)
    demo.launch(max_threads=4, show_error=True)