import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import langextract as lx
import json
import re
from typing import List, Dict, Any, Tuple, Optional
import pandas as pd
import requests
import time
import os
from pathlib import Path
import tempfile
import torch
import spaces
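# `spaces` provides Hugging Face ZeroGPU support on Spaces hardware; GPU-bound
# functions can optionally be wrapped with the @spaces.GPU decorator there.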

# Global variables to store the loaded model and tokenizer
dental_model = None
dental_tokenizer = None
current_token = None
output_directory = Path(".")


def load_dental_transformers_model():
    """Load the dental model using transformers"""
    global dental_model, dental_tokenizer

    if dental_model is None or dental_tokenizer is None:
        try:
            print("Loading transformers model... This may take a moment on first run.")

            # Load tokenizer and model
            dental_tokenizer = AutoTokenizer.from_pretrained("yasserrmd/DentaInstruct-1.2B")
            dental_model = AutoModelForCausalLM.from_pretrained(
                "yasserrmd/DentaInstruct-1.2B",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )

            # Set pad token if not set
            if dental_tokenizer.pad_token is None:
                dental_tokenizer.pad_token = dental_tokenizer.eos_token

            print("Model loaded successfully!")
            return dental_model, dental_tokenizer
        except Exception as e:
            print(f"Error loading transformers model: {str(e)}")
            return None, None

    return dental_model, dental_tokenizer

def generate_dental_response(
    question: str,
    max_tokens: int = 2048,
    temperature: float = 0.7
) -> str:
    """Generate response using transformers model"""
    # Load model and tokenizer
    model, tokenizer = load_dental_transformers_model()

    if not model or not tokenizer:
        return "❌ Transformers model not available."

    try:
        system_prompt = """You are a dental AI assistant. When providing medication recommendations, you must:
1. Always provide a complete 3-day medication regimen
2. Include detailed descriptions for each medication including exact dosage amounts, frequency, duration, mechanism of action
3. Organize the response clearly with medication names, dosages, and instructions
4. Always include a disclaimer about professional medical consultation"""
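        # Pack the system prompt and user question into the chat-message format
        # consumed by tokenizer.apply_chat_template below.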
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question}
        ]

        # Apply chat template
        try:
            # Try with chat template first
            input_text = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False
            )
        except Exception:
            # Fallback to simple concatenation if chat template fails
            input_text = f"{system_prompt}\n\nUser: {question}\n\nAssistant:"

        # Tokenize the input
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )

        # Remove token_type_ids if present (not needed for most models)
        if 'token_type_ids' in inputs:
            del inputs['token_type_ids']

        # Move to device
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        # Decode only the new tokens (response)
        response = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[-1]:],
            skip_special_tokens=True
        )

        return response.strip()
    except Exception as e:
        return f"❌ Error generating response with transformers model: {str(e)}"

def extract_medications(text: str, gemini_api_key: str = "") -> Tuple[str, str, str]:
    """Extract medication information from text"""
    try:
        # Check if API key is provided
        if not gemini_api_key or not gemini_api_key.strip():
            return "❌ Please provide a valid Gemini API key for medication extraction.", text, ""

        model_api_key = gemini_api_key.strip()

        prompt_description = "Extract medication information including medication name, dosage, route, frequency, and duration in the order they appear in the text."
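        # langextract is example-driven: the few-shot example below shows the model
        # which extraction classes to emit and that each extraction should quote an
        # exact span from the source text.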
        examples = [
            lx.data.ExampleData(
                text="Patient was given 250 mg IV Cefazolin TID for one week.",
                extractions=[
                    lx.data.Extraction(extraction_class="dosage", extraction_text="250 mg"),
                    lx.data.Extraction(extraction_class="route", extraction_text="IV"),
                    lx.data.Extraction(extraction_class="medication", extraction_text="Cefazolin"),
                    lx.data.Extraction(extraction_class="frequency", extraction_text="TID"),
                    lx.data.Extraction(extraction_class="duration", extraction_text="for one week")
                ]
            )
        ]

        result = lx.extract(
            text_or_documents=text,
            prompt_description=prompt_description,
            examples=examples,
            model_id="gemini-2.0-flash-exp",
            api_key=model_api_key
        )

        if result and result.extractions:
            # Create DataFrame for display
            extraction_data = []
            for entity in result.extractions:
                position_info = ""
                if entity.char_interval:
                    start, end = entity.char_interval.start_pos, entity.char_interval.end_pos
                    position_info = f"{start}-{end}"

                extraction_data.append({
                    "Type": entity.extraction_class.capitalize(),
                    "Text": entity.extraction_text,
                    "Position": position_info
                })

            df = pd.DataFrame(extraction_data)

            # Create highlighted text
            highlighted_text = highlight_text_with_extractions(text, result.extractions)

            # Save and visualize the results
            try:
                lx.io.save_annotated_documents([result], output_name="medical_ner_extraction.jsonl", output_dir=output_directory)

                # Generate the interactive visualization
                html_content = lx.visualize("medical_ner_extraction.jsonl")
                # lx.visualize may return an IPython HTML object; unwrap it to a raw
                # HTML string so gr.HTML can render it.
                if hasattr(html_content, "data"):
                    html_content = html_content.data

                return df.to_string(index=False), highlighted_text, html_content
            except Exception as viz_error:
                # If visualization fails, still return the other results
                return df.to_string(index=False), highlighted_text, f"⚠️ Visualization generation failed: {str(viz_error)}"
        else:
            return "ℹ️ No medications found in the text.", text, ""
    except Exception as e:
        return f"❌ Error extracting medications: {str(e)}", text, ""

def highlight_text_with_extractions(text: str, extractions: List[Any]) -> str:
    """Highlight extracted entities in the original text"""
    if not extractions:
        return text

    # Sort extractions by position to avoid overlap issues
    sorted_extractions = sorted(
        [e for e in extractions if e.char_interval],
        key=lambda x: x.char_interval.start_pos
    )
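    # Markers are inserted left to right; each inserted marker lengthens the string,
    # so a running offset keeps the later character positions aligned.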
    highlighted_text = text
    offset = 0

    for extraction in sorted_extractions:
        start = extraction.char_interval.start_pos + offset
        end = extraction.char_interval.end_pos + offset
        original = highlighted_text[start:end]
        highlighted = f'**[{extraction.extraction_class.upper()}]** {original} **[/{extraction.extraction_class.upper()}]**'
        highlighted_text = highlighted_text[:start] + highlighted + highlighted_text[end:]
        offset += len(highlighted) - len(original)

    return highlighted_text

def dental_consultation_interface(
    question: str,
    max_tokens: int,
    temperature: float
) -> str:
    """Main interface for dental consultation"""
    if not question.strip():
        return "Please enter a question first."

    response = generate_dental_response(
        question=question,
        max_tokens=max_tokens,
        temperature=temperature
    )

    token_count = len(response.split())
    return f"{response}\n\n---\nResponse length: ~{token_count} words"

def medication_extraction_interface(text: str, gemini_api_key: str) -> Tuple[str, str, str]:
    """Interface for medication extraction"""
    if not text.strip():
        return "Please enter text for medication extraction.", "", ""

    return extract_medications(text, gemini_api_key)


# Quick question options
QUICK_QUESTIONS = [
    "I have a toothache with throbbing pain, provide 3-day medication",
    "What causes tooth pain and how to treat it?",
    "How to prevent cavities?",
    "What are the signs of gum disease?",
    "Emergency dental care advice",
    "Post-extraction care instructions with medications",
    "Wisdom tooth pain relief medication regimen"
]

def create_gradio_interface():
    """Create the main Gradio interface"""
    # Custom CSS
    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .main-header {
        text-align: center;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        color: white;
        padding: 2rem;
        border-radius: 10px;
        margin-bottom: 2rem;
    }
    .disclaimer {
        background-color: #fff3cd;
        border: 1px solid #ffeaa7;
        border-radius: 5px;
        padding: 1rem;
        margin: 1rem 0;
        color: #856404;
    }
    """

    with gr.Blocks(css=css, title="🦷 Dental AI Assistant") as demo:
        # Header
        gr.HTML("""
        <div class="main-header">
            <h1>🦷 Dental AI Assistant</h1>
            <p>Advanced dental consultation and medication extraction</p>
        </div>
        """)

        with gr.Tabs():
            # Tab 1: Dental Consultation
            with gr.TabItem("💬 Dental Consultation"):
                with gr.Row():
                    with gr.Column(scale=2):
                        question_input = gr.Textbox(
                            label="Ask your dental question:",
                            placeholder="e.g., I have a toothache, what should I do?",
                            lines=3
                        )

                        quick_question = gr.Dropdown(
                            choices=[""] + QUICK_QUESTIONS,
                            label="Or select a quick question:",
                            value=""
                        )

                        # Update question input when quick question is selected
                        quick_question.change(
                            fn=lambda x: x if x else "",
                            inputs=[quick_question],
                            outputs=[question_input]
                        )

                        with gr.Row():
                            consult_btn = gr.Button("Get Dental Advice", variant="primary")
                            clear_btn = gr.Button("🗑️ Clear")

                    with gr.Column(scale=1):
                        gr.Markdown("### ⚙️ Settings")

                        max_tokens = gr.Slider(
                            minimum=500,
                            maximum=4000,
                            value=2048,
                            step=100,
                            label="Max Response Tokens"
                        )

                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.7,
                            step=0.1,
                            label="Temperature (Creativity)"
                        )

                        gr.Markdown("""
                        **Model Info:**
                        - Using Transformers Model
                        - Optimized for GPU/CPU
                        - Auto device mapping
                        """)

                response_output = gr.Textbox(
                    label="🩺 AI Response:",
                    lines=15,
                    max_lines=25
                )
                consult_btn.click(
                    fn=dental_consultation_interface,
                    inputs=[question_input, max_tokens, temperature],
                    outputs=[response_output]
                )

                # Clear both the question box and the response box
                clear_btn.click(
                    fn=lambda: ("", ""),
                    inputs=[],
                    outputs=[question_input, response_output]
                )
            # Tab 2: Medication Extraction
            with gr.TabItem("Medication Extraction"):
                with gr.Row():
                    with gr.Column():
                        extraction_text = gr.Textbox(
                            label="Enter text for medication extraction:",
                            placeholder="Paste medical text here to extract medication information...",
                            lines=10
                        )

                        gemini_api_key = gr.Textbox(
                            label="Gemini API Key",
                            placeholder="AIza...",
                            type="password",
                            info="Required for medication extraction"
                        )

                        with gr.Row():
                            extract_btn = gr.Button("🧬 Extract Medications", variant="primary")
                            copy_from_consultation = gr.Button("Copy from Consultation")

                with gr.Row():
                    with gr.Column():
                        extraction_results = gr.Textbox(
                            label="Extracted Medications:",
                            lines=8
                        )
                    with gr.Column():
                        highlighted_text = gr.Textbox(
                            label="🎯 Highlighted Text:",
                            lines=8
                        )

                with gr.Row():
                    visualization_html = gr.HTML(
                        label="🎨 Interactive Visualization:",
                        value="<p style='text-align: center; color: #666;'>Visualization will appear here after extraction</p>"
                    )

                extract_btn.click(
                    fn=medication_extraction_interface,
                    inputs=[extraction_text, gemini_api_key],
                    outputs=[extraction_results, highlighted_text, visualization_html]
                )

                # Copy response from consultation tab to extraction
                copy_from_consultation.click(
                    fn=lambda x: x,
                    inputs=[response_output],
                    outputs=[extraction_text]
                )
            # Tab 3: Help & Setup
            with gr.TabItem("Help & Setup"):
                gr.Markdown("""
                ## Getting Started

                ### Model:
                **Transformers Model**: Uses the Hugging Face Transformers library with automatic device mapping

                ### API Key Setup:
                **Gemini API Key** (required for medication extraction):
                1. Go to [Google AI Studio](https://aistudio.google.com)
                2. Click 'Get API Key'
                3. Create a new API key

                ### 📦 Installation Requirements:
                ```bash
                pip install gradio transformers accelerate langextract pandas requests torch spaces
                ```

                ### 🩺 Features:
                - **Dental Consultation**: Get AI-powered dental advice with detailed medication regimens
                - **Medication Extraction**: Extract and highlight medications from medical text
                - **Interactive Visualization**: Visual representation of extracted medication entities
                - **Quick Questions**: Pre-built common dental questions
                - **Customizable Settings**: Adjust response length and creativity
                - **GPU/CPU Support**: Automatic device detection and optimization

                ### ⚠️ Important Disclaimer:
                This AI assistant is for educational purposes only. Always consult with a qualified dentist for professional medical advice.
                """)

        # Footer
        gr.HTML("""
        <div class="disclaimer">
            <p><strong>⚠️ Disclaimer:</strong> This AI assistant is for educational purposes only.
            Always consult with a qualified dentist for professional medical advice.</p>
            <p style="text-align: center; margin-top: 1rem;">
                🦷 Built with Gradio | Gemini | Powered by yasserrmd/DentaInstruct-1.2B
            </p>
        </div>
        """)

    return demo

if __name__ == "__main__":
    # Create and launch the interface
    demo = create_gradio_interface()
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        ssr_mode=False  # Disable SSR for Spaces compatibility
    )