import json
import os
import pickle
import re
import tempfile

import cv2
import gradio as gr
import pandas as pd
import requests
from rapidocr import RapidOCR, EngineType, LangCls, LangDet, LangRec, ModelType, OCRVersion
from symspellpy import SymSpell, Verbosity


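# Dosage-form keywords ("TAB", "CAP", "INJ", ...) act as anchors: a line that
# contains one of them plus a number is very likely a medication entry, and the
# keyword also marks where the drug name starts when OCR glues tokens together.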
ANCHOR_PREFIXES = ["tab", "cap", "t."]

ANCHORS = [
    r"tab\.?", r"cap\.?", r"inj\.?", r"syp\.?", r"syr\.?",
    r"sol\.?", r"susp\.?", r"oint\.?", r"crm\.?", r"gel\.?",
    r"drops?", r"powder", r"dragees?", r"t\.?", r"c\.?"
]
ANCHOR_PATTERN = re.compile(r"\b(" + "|".join(ANCHORS) + r")", re.IGNORECASE)


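# Phrases that mark administrative or clinical lines (doctor details, vitals,
# lab reports, contact info), which should never be treated as medications.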
NON_MED_PATTERNS = [
    r"emergency", r"contact", r"please",
    r"nephrologist", r"cardiologist",
    r"opinion", r"inform", r"kftafter", r"prescription",
    r"follow[- ]up", r"dr\.", r"physician", r"clinic",
    r"hospital", r"diagnosed", r"treatment", r"patient",
    r"age[: ]", r"sex[: ]", r"weight[: ]", r"height[: ]",
    r"bp[: ]", r"pulse[: ]", r"temperature[: ]",
    r"investigation", r"advised", r"admission", r"discharge",
    r"report", r"lab[: ]", r"laboratory", r"radiology",
    r"address", r"phone[: ]", r"mobile[: ]", r"email[: ]",
    r"signature", r"regd\.?", r"drugs? prescribed"
]
NON_MED_REGEX = re.compile("|".join(NON_MED_PATTERNS), re.IGNORECASE)


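# Tokens that also appear in everyday English (vitamins, minerals, "xl" suffixes)
# but must not be discarded by the English-vocabulary filter during matching.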
rescue_list = {"d3", "b12", "k2", "iron", "zinc", "calcium", "vit", "xl"}


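# A line is worth parsing only if it avoids the non-medication phrases, contains
# a dosage-form anchor, and includes at least one digit (strength or quantity).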
def is_potential_med_line(text: str) -> bool:
    t = text.lower()
    if NON_MED_REGEX.search(t):
        return False
    if not ANCHOR_PATTERN.search(t):
        return False
    if not re.search(r"\d", t):
        return False
    return True


def validate_drug_match(term: str, drug_db, drug_token_index):
    """
    Map SymSpell term -> canonical database drug, or None if noise.
    """
    if term in drug_db:
        return term
    if term in drug_token_index:
        # Several drugs may share this token; pick one deterministically.
        return sorted(drug_token_index[term])[0]
    return None


def normalize_anchored_tokens(raw_text: str):
    """
    Use TAB/CAP/T. as anchors, not something to delete:
    - 'TABCLOPITAB75MG TAB' -> ['clopitab']
    - 'TAB SOBISISTAB' -> ['sobisistab']
    - 'TABSTARPRESSXL25MGTAB' -> ['starpressxl']
    """
    t = raw_text.lower()

    # Strip dosage units and any remaining digits before tokenizing.
    t = re.sub(r"\d+\s*(mg|ml|gm|%|u|mcg)", " ", t)
    t = re.sub(r"\d+", " ", t)
    tokens = t.split()

    normalized = []
    skip_next = False

    for i, tok in enumerate(tokens):
        if skip_next:
            skip_next = False
            continue

        base = tok

        # Case 1: anchor glued to the drug name, e.g. 'tabclopitab' -> 'clopitab'.
        for pref in ANCHOR_PREFIXES:
            if base.startswith(pref) and len(base) > len(pref):
                base = base[len(pref):]
                break

        # Case 2: anchor is its own token, e.g. 'tab sobisistab' -> 'sobisistab';
        # merge with the next token and strip its prefix as well.
        if base in ["tab", "cap", "t"]:
            if i + 1 < len(tokens):
                merged = tokens[i + 1]
                for pref in ANCHOR_PREFIXES:
                    if merged.startswith(pref) and len(merged) > len(pref):
                        merged = merged[len(pref):]
                        break
                base = merged
                skip_next = True
            else:
                continue

        base = base.strip()
        if len(base) >= 3:
            normalized.append(base)

    return normalized


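# Build all lookup structures from data/Dataset.csv: the set of known drug names,
# a SymSpell dictionary for fuzzy matching, a token-to-drug index, and a list of
# common English words used to filter out ordinary vocabulary.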
def initialize_database():
    data_path = os.path.join(os.path.dirname(__file__), "data/Dataset.csv")
    df = pd.read_csv(data_path)
    drug_db = set(df["Combined_Drugs"].astype(str).str.lower().str.strip())
    sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)

    # Register every drug name (and its longer word parts) in the SymSpell dictionary.
    for drug in drug_db:
        d = drug.lower()
        sym_spell.create_dictionary_entry(d, 100000)
        parts = d.split()
        if len(parts) > 1:
            for p in parts:
                if len(p) > 3:
                    sym_spell.create_dictionary_entry(p, 100000)

    # Index each token back to the full drug names that contain it.
    drug_token_index = {}
    for full in drug_db:
        for tok in full.split():
            if len(tok) < 3:
                continue
            drug_token_index.setdefault(tok, set()).add(full)

    # Common-English vocabulary used to filter out ordinary words; fall back to a
    # tiny built-in set if the download fails.
    try:
        url = (
            "https://raw.githubusercontent.com/first20hours/"
            "google-10000-english/master/google-10000-english-no-swears.txt"
        )
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        english_vocab = set(response.text.split())
    except Exception:
        english_vocab = {"the", "and", "tab", "cap", "mg", "ml"}

    return {
        'drug_db': drug_db,
        'sym_spell': sym_spell,
        'drug_token_index': drug_token_index,
        'english_vocab': english_vocab,
        'rescue_list': rescue_list,
        'NON_MED_REGEX': NON_MED_REGEX,
        'ANCHOR_PATTERN': ANCHOR_PATTERN,
        'ANCHOR_PREFIXES': ANCHOR_PREFIXES
    }


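# End-to-end pipeline for one prescription image: load (or build) the drug
# database, run RapidOCR, filter the recognized lines, normalize and
# spell-correct candidate tokens, and map them to canonical drug names.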
def process_image_ocr(image_path):
    # Load the prebuilt lookup structures from the pickle cache; rebuild them
    # from the CSV dataset if the cache does not exist yet.
    cache_dir = os.path.join(os.path.dirname(__file__), "cache")
    cache_path = os.path.join(cache_dir, "database_cache.pkl")
    try:
        with open(cache_path, 'rb') as f:
            cache = pickle.load(f)
    except FileNotFoundError:
        print("database_cache.pkl not found; initializing database...")
        cache = initialize_database()
        # Persist the freshly built structures so later runs can skip the rebuild.
        os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(cache, f)

    drug_db = cache['drug_db']
    sym_spell = cache['sym_spell']
    drug_token_index = cache['drug_token_index']
    english_vocab = cache['english_vocab']
    rescue_list = cache['rescue_list']

    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not load image from {image_path}")

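    # OCR engine configuration: PP-OCRv4 mobile models running on ONNX Runtime
    # (the CH model variants also cover Latin letters and digits).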
    ocr_engine = RapidOCR(
        params={
            "Global.max_side_len": 2000,
            "Det.engine_type": EngineType.ONNXRUNTIME,
            "Det.lang_type": LangDet.CH,
            "Det.model_type": ModelType.MOBILE,
            "Det.ocr_version": OCRVersion.PPOCRV4,
            "Cls.engine_type": EngineType.ONNXRUNTIME,
            "Cls.lang_type": LangCls.CH,
            "Cls.model_type": ModelType.MOBILE,
            "Cls.ocr_version": OCRVersion.PPOCRV4,
            "Rec.engine_type": EngineType.ONNXRUNTIME,
            "Rec.lang_type": LangRec.CH,
            "Rec.model_type": ModelType.MOBILE,
            "Rec.ocr_version": OCRVersion.PPOCRV4,
        }
    )

    ocr_result = ocr_engine(
        img,
        use_det=True,
        use_cls=True,
        use_rec=True,
        text_score=0.5,
        box_thresh=0.5,
        unclip_ratio=1.6,
        return_word_box=False,
    )

    # Recognized text lines; guard against the engine returning None when nothing is detected.
    ocr_data = ocr_result.txts or []

    found_meds_with_originals = {}

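    # For each recognized line: filter, normalize, spell-correct, and map tokens
    # to canonical drug names, keeping the original OCR line as evidence.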
    for item in ocr_data:
        text_lower = item.lower()

        # Skip lines that do not look like medication entries.
        if not is_potential_med_line(text_lower):
            continue

        # Extra guard against doctor-name lines such as 'Dr Sharma'.
        if "dr." in text_lower or "dr " in text_lower:
            continue

        # Strip dosage info and anchor prefixes to get candidate drug tokens.
        candidate_tokens = normalize_anchored_tokens(item)

        # Let SymSpell re-segment run-together tokens before matching.
        if candidate_tokens:
            segmentation = sym_spell.word_segmentation(" ".join(candidate_tokens))
            candidate_tokens = segmentation.corrected_string.split()

        for word in candidate_tokens:
            if len(word) < 3:
                continue

            # Drop ordinary English words unless they are on the rescue list.
            if word in english_vocab and word not in rescue_list:
                continue

            # First try an exact match against the drug database.
            canonical = validate_drug_match(word, drug_db, drug_token_index)
            if canonical:
                if canonical not in found_meds_with_originals:
                    found_meds_with_originals[canonical] = []
                if item not in found_meds_with_originals[canonical]:
                    found_meds_with_originals[canonical].append(item)
                continue

            # Otherwise fall back to a fuzzy SymSpell lookup within edit distance 1.
            suggestions = sym_spell.lookup(
                word, Verbosity.CLOSEST, max_edit_distance=1
            )
            if not suggestions:
                continue

            cand = suggestions[0].term
            canonical = validate_drug_match(cand, drug_db, drug_token_index)
            if not canonical:
                continue

            if canonical not in found_meds_with_originals:
                found_meds_with_originals[canonical] = []
            if item not in found_meds_with_originals[canonical]:
                found_meds_with_originals[canonical].append(item)

    print("\nJSON Output:")
    print(json.dumps(found_meds_with_originals, indent=4))

    return found_meds_with_originals


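# Gradio callback: receives the uploaded PIL image and returns the extracted
# drugs as a pretty-printed JSON string.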
def process_prescription(image):
    if image is None:
        return "No image uploaded."

    # Save the uploaded image to a temporary JPEG so the OCR pipeline can read it
    # from disk, then clean the file up afterwards.
    with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
        tmp_path = tmp.name
    try:
        image.convert("RGB").save(tmp_path)
        result = process_image_ocr(tmp_path)
    finally:
        os.remove(tmp_path)
    return json.dumps(result, indent=4)


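# Minimal Gradio UI: one image input, one text box showing the JSON result.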
iface = gr.Interface(
    fn=process_prescription,
    inputs=gr.Image(type="pil", label="Upload Prescription Image"),
    outputs=gr.Textbox(label="Extracted Drugs", lines=20),
    title="MediBot - Drug Extraction from Prescriptions",
    description="Upload a prescription image to extract drug information."
)

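# Running this file directly starts a local Gradio server (http://127.0.0.1:7860
# by default); pass share=True to launch() for a temporary public link.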
if __name__ == "__main__":
    iface.launch()