| | import gradio as gr |
| | import argparse |
| | import torch |
| | from torch import cuda |
| | import torch.nn.functional as F |
| | import torchvision.transforms.functional as TF |
| | from torchvision import transforms |
| | from PIL import Image |
| | import skimage.morphology, skimage.io |
| | import cv2 |
| | import numpy as np |
| | import random |
| | from transformers import StoppingCriteria, StoppingCriteriaList |
| | from copy import deepcopy |
| | from medomni.common.config import Config |
| | from medomni.common.dist_utils import get_rank |
| | from medomni.common.registry import registry |
| | import torchio as tio |
| | import nibabel as nib |
| | from scipy import ndimage, misc |
| | import time |
| | import ipdb |
| |
|
| | |
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is how the script uses it.

    Returns:
        argparse.Namespace with ``cfg_path`` (str) and ``options``
        (list of str, or ``None`` when ``--options`` is not given).
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    # NOTE(review): the help text points users at --cfg-options, which this
    # parser does not define.
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair in xxx=yyy format will be merged into config file (deprecate), change to --cfg-options instead.",
    )
    return parser.parse_args(argv)
| |
|
# Use the GPU when one is visible, otherwise run everything on CPU.
device = 'cuda' if cuda.is_available() else 'cpu'

args = parse_args()
cfg = Config(args)

# Resolve the model class named by the config's `arch` and load the
# 'hyzhou/MedVersa' checkpoint in inference mode.
model_config = cfg.model_cfg
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()

# Files submitted with the previous request. gradio_interface compares the new
# upload against this to decide whether to start a fresh chat. (A `global`
# statement at module scope is a no-op, so it is not needed here.)
global_images = None
| |
|
def seg_2d_process(image_path, pred_mask, img_size=224):
    """Turn a raw 2D segmentation prediction into display-ready image/mask lists.

    Keeps only the largest connected component of ``pred_mask``, fills its
    contours, and resizes the filled mask to the first input image's
    resolution.

    Args:
        image_path: Sequence of image file paths; only ``image_path[0]`` is used.
        pred_mask: 2D array (boolean when produced by ``task_seg_2d``) or
            ``None``. It is not modified.
        img_size: Kept for backward compatibility. The drawing canvas now
            follows ``pred_mask``'s own shape, which the previous code assumed
            to be ``img_size`` x ``img_size``.

    Returns:
        Tuple ``(image, mask)``:
        - ``image`` is a one-element list holding the RGB ``PIL.Image``.
        - ``mask`` is a one-element list holding a float array of shape
          (height, width) with values 0/1.

    Raises:
        ValueError: if OpenCV cannot read ``image_path[0]``.
    """
    image = cv2.imread(image_path[0])
    if image is None:
        raise ValueError(f"Could not read image file: {image_path[0]}")
    height, width = image.shape[:2]
    rgb_image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
    # np.zeros takes (rows, cols) == (height, width).
    empty_mask = [np.zeros((height, width))]

    if pred_mask is None or pred_mask.sum() == 0:
        return rgb_image, empty_mask

    # Keep only the largest connected foreground component (label 0 is background).
    labels = skimage.morphology.label(pred_mask)
    largest_label = np.argmax(np.bincount(labels.ravel())[1:]) + 1
    component = (labels == largest_label).astype(np.uint8)

    contours, _ = cv2.findContours(component, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return rgb_image, empty_mask

    # Pass the contour list itself. Stacking all contours into one array made
    # drawContours treat every boundary point as its own contour, so FILLED
    # painted only the outline. This matches seg_3d_process.
    canvas = np.zeros(component.shape[:2])
    canvas = cv2.drawContours(canvas, contours, -1, 255, thickness=cv2.FILLED)
    mask = cv2.resize(canvas, (width, height), interpolation=cv2.INTER_NEAREST) / 255
    return rgb_image, [mask]
| |
|
def seg_3d_process(image_path, seg_mask):
    """Slice a CT volume and its predicted 3D mask into display-ready 2D lists.

    The NIfTI volume is windowed with ``window_scan``. Its last axis is moved
    to the front, and each slice is rotated and flipped for display. The
    predicted mask is resized back to the volume's shape and gets the same
    rotate/flip, so both stay aligned.

    Args:
        image_path: Sequence of paths; only ``image_path[0]`` (a NIfTI file) is read.
        seg_mask: 3D mask array from ``task_seg_3d`` at model resolution, or ``None``.

    Returns:
        Tuple ``(image_slices, contour_slices)``, both with one entry per slice:
        - ``image_slices`` holds uint8 ``PIL.Image`` objects.
        - ``contour_slices`` holds arrays of the same shape; filled regions
          are 1 and everything else is 0.
    """
    volume = nib.load(image_path[0]).get_fdata()
    image = window_scan(volume).transpose(2, 0, 1).astype(np.uint8)
    has_mask = seg_mask is not None and seg_mask.sum() != 0
    if has_mask:
        seg_mask = resize_back_volume_abd(seg_mask, image.shape).astype(np.uint8)

    image_slices = []
    contour_slices = []
    # Iterate every slice in both branches. The old empty-mask branch used an
    # undefined loop index and produced at most one slice.
    for i in range(image.shape[0]):
        # OpenCV wants C-contiguous input; rot90/fliplr return strided views.
        slice_img = np.ascontiguousarray(np.fliplr(np.rot90(image[i])))
        image_slices.append(Image.fromarray(slice_img))
        if not has_mask:
            contour_slices.append(np.zeros_like(slice_img))
            continue
        slice_mask = np.ascontiguousarray(np.fliplr(np.rot90(seg_mask[i])))
        contours, _ = cv2.findContours(slice_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        if contours:
            # Draw on a canvas with the rotated slice's shape. The old code used
            # the un-rotated shape and passed (rows, cols) as cv2.resize's
            # (width, height), which transposed non-square slices.
            filled = np.zeros(slice_mask.shape)
            filled = cv2.drawContours(filled, contours, -1, 255, thickness=cv2.FILLED) / 255
            contour_slices.append(filled)
        else:
            contour_slices.append(np.zeros_like(slice_img))
    return image_slices, contour_slices
| |
|
def det_2d_process(image_path, box):
    """Draw predicted detection box(es) onto the first input image.

    Args:
        image_path: Sequence of image paths; only ``image_path[0]`` is read.
        box: ``None``, or array-like of ``[x1, y1, x2, y2]`` coordinates given as
            fractions of the image width/height. Both forms are accepted:
            - a flat 4-vector (what ``task_det_2d`` yields for one token);
            - an (N, 4) array of several boxes.

    Returns:
        One-element list holding the annotated image as an RGB ``PIL.Image``.

    Raises:
        ValueError: if OpenCV cannot read ``image_path[0]``.
    """
    image = cv2.imread(image_path[0])
    if image is None:
        raise ValueError(f"Could not read image file: {image_path[0]}")
    if box is not None:
        height, width = image.shape[:2]
        for x1, y1, x2, y2 in np.asarray(box, dtype=np.float64).reshape(-1, 4):
            color = tuple(np.random.random(size=3) * 256)
            top_left = (int(x1 * width), int(y1 * height))
            bottom_right = (int(x2 * width), int(y2 * height))
            image = cv2.rectangle(image, top_left, bottom_right, color, 10)
    # cv2.imread returns BGR while PIL expects RGB (seg_2d_process converts the same way).
    return [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
| |
|
def window_scan(scan, window_center=50, window_width=400):
    """
    Map a CT scan's intensities through a display window onto 0..255.

    Values below ``window_center - window_width // 2`` become 0, values above
    ``window_center + window_width // 2`` become 255, and values in between
    are scaled linearly and truncated to uint8.

    Parameters:
    scan (numpy.ndarray): CT scan array
    window_center (int): centre of the intensity window
    window_width (int): full width of the intensity window

    Returns:
    numpy.ndarray: uint8 array with the same shape as ``scan``
    """
    half_width = window_width // 2
    low, high = window_center - half_width, window_center + half_width
    unit_range = (np.clip(scan, low, high) - low) / (high - low)
    return (unit_range * 255).astype(np.uint8)
| |
|
def task_seg_2d(model, preds, hidden_states, image):
    """Decode a 2D segmentation mask conditioned on the model's 2D segmentation token(s).

    Args:
        model: MedVersa model exposing ``seg_token_idx_2d``, ``model_seg_2d``
            (encoder / decoder / segmentation_head), ``text2seg_2d`` and
            ``text2seg_2d_gn``.
        preds: 1-D tensor of generated token ids.
        hidden_states: Per-step hidden states from ``generate``; the last
            entry of ``hidden_states[pos]`` is used for each token position.
        image: Image batch; ``image.unsqueeze(0)[:, 0]`` is fed to the encoder.

    Returns:
        Boolean numpy mask (sigmoid >= 0.5, squeezed), or ``None`` when
        ``preds`` contains no 2D segmentation token.
    """
    positions = torch.where(preds == model.seg_token_idx_2d)[0].cpu().numpy()
    features = model.model_seg_2d.encoder(image.unsqueeze(0)[:, 0])
    token_states = [hidden_states[pos][-1] for pos in positions]
    if not token_states:
        return None
    text_embedding = model.text2seg_2d(torch.cat(token_states).squeeze()).unsqueeze(0)
    # Broadcast the projected text embedding over the two trailing (spatial) dims.
    conditioned = features[-1] + text_embedding[..., None, None]
    features[-1] = model.text2seg_2d_gn(conditioned)
    decoded = model.model_seg_2d.decoder(*features)
    logits = model.model_seg_2d.segmentation_head(decoded)
    probabilities = F.sigmoid(logits)
    return probabilities.to(torch.float32).cpu().squeeze().numpy() >= 0.5
| |
|
def task_seg_3d(model, preds, hidden_states, img_embeds_list):
    """Decode a 3D segmentation mask conditioned on the model's 3D segmentation token(s).

    Works on a deep copy of ``img_embeds_list``. The caller's list and tensors
    are left untouched.

    Args:
        model: MedVersa model exposing ``seg_token_idx_3d``, ``text2seg_3d``,
            ``text2seg_3d_gn`` and ``visual_encoder_3d``.
        preds: 1-D tensor of generated token ids.
        hidden_states: Per-step hidden states from ``generate``.
        img_embeds_list: Multi-scale image features from ``encode_img``. The
            last entry is conditioned on the text embedding.

    Returns:
        Boolean numpy mask (sigmoid >= 0.5, squeezed), or ``None`` when
        ``preds`` contains no 3D segmentation token.
    """
    embeds = deepcopy(img_embeds_list)
    positions = torch.where(preds == model.seg_token_idx_3d)[0].cpu().numpy()
    token_states = [hidden_states[pos][-1] for pos in positions]
    if not token_states:
        return None
    text_embedding = model.text2seg_3d(torch.cat(token_states).squeeze().unsqueeze(0))
    # Broadcast the projected text embedding over the three trailing (volume) dims.
    embeds[-1] = model.text2seg_3d_gn(embeds[-1] + text_embedding[..., None, None, None])
    logits = model.visual_encoder_3d(encoder_only=False, x_=embeds)
    probabilities = F.sigmoid(logits)
    return probabilities.to(torch.float32).cpu().squeeze().numpy() >= 0.5
| |
|
def task_det_2d(model, preds, hidden_states):
    """Predict detection box coordinates from the model's detection token(s).

    Args:
        model: MedVersa model exposing ``det_token_idx`` and ``text_det``.
        preds: 1-D tensor of generated token ids.
        hidden_states: Per-step hidden states from ``generate``; the last
            entry of ``hidden_states[pos]`` is used for each token position.

    Returns:
        float32 numpy array produced by ``model.text_det``, or ``None`` when
        ``preds`` contains no detection token. ``None`` is the sentinel that
        ``det_2d_process`` checks. The previous fallback,
        ``torch.zeros_like(indices)``, was given a numpy array and raised
        TypeError.
    """
    positions = torch.where(preds == model.det_token_idx)[0].cpu().numpy()
    token_states = [hidden_states[pos][-1] for pos in positions]
    if not token_states:
        return None
    stacked = torch.cat(token_states).squeeze()
    det_states = model.text_det(stacked).detach().cpu()
    return det_states.to(torch.float32).numpy()
| |
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the first sequence ends with any given token-id sequence."""

    def __init__(self, stops=None):
        """
        Args:
            stops: Iterable of 1-D tensors of token ids. ``None`` (the default)
                means no stop sequences; this avoids a shared mutable default.
        """
        super().__init__()
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        """Return True if ``input_ids[0]`` currently ends with one of ``self.stops``.

        Only the first row of the batch is inspected. Extra keyword arguments
        are accepted and ignored so the criterion tolerates callers that
        forward them.
        """
        sequence = input_ids[0]
        for stop in self.stops:
            # A stop longer than the sequence cannot match. Without this check
            # the shorter tail would broadcast against `stop` and could
            # compare equal spuriously.
            if len(stop) > len(sequence):
                continue
            if torch.all(stop == sequence[-len(stop):]).item():
                return True
        return False
| |
|
def resize_back_volume_abd(img, target_size):
    """Nearest-neighbour resample the first three axes of ``img`` to ``target_size``.

    Args:
        img: 3D array (for example a predicted mask at model resolution).
        target_size: Sequence whose first three entries give the desired
            sizes of ``img``'s axes 0, 1 and 2.

    Returns:
        Array of shape ``target_size[:3]``. ``order=0`` means every output
        value is copied from an input voxel and no new values are interpolated.
    """
    zoom_factors = [
        1 / (current / desired)
        for desired, current in zip(target_size[:3], img.shape[:3])
    ]
    return ndimage.zoom(img, zoom_factors, order=0)
| |
|
def resize_volume_abd(img, target_shape=(192, 192, 64), clip_range=(-200, 300)):
    """Clip a 3D volume's intensities and nearest-neighbour resample it.

    The previous version clipped ``img`` in place, silently modifying the
    caller's array. This version clips into a new array. The resampled result
    is unchanged.

    Args:
        img: 3D numeric array. Axes 0/1/2 are treated as width/height/depth,
            following the original variable naming.
        target_shape: Output shape for axes 0, 1, 2. The default
            ``(192, 192, 64)`` is the size used so far.
        clip_range: ``(low, high)`` intensity bounds applied before resampling.
            The default is ``(-200, 300)``.

    Returns:
        New array of shape ``target_shape`` resampled with ``order=0``.
    """
    low, high = clip_range
    clipped = np.clip(img, low, high)
    zoom_factors = tuple(desired / current for desired, current in zip(target_shape, clipped.shape))
    return ndimage.zoom(clipped, zoom_factors, order=0)
| |
|
def load_and_preprocess_image(image):
    """Convert a PIL RGB image into a normalised bfloat16 batch of one.

    The image is resized to 224x224 and scaled to [0, 1] by ``ToTensor``, then
    normalised per channel. The mean/std values appear to be the constants
    commonly used for CLIP preprocessing.

    Returns:
        Tensor of shape (1, 3, 224, 224) for an RGB input, dtype ``torch.bfloat16``.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    as_tensor = pipeline(image)
    return as_tensor.type(torch.bfloat16).unsqueeze(0)
| |
|
def load_and_preprocess_volume(image):
    """Load a NIfTI volume and turn it into a z-normalised bfloat16 tensor.

    ``resize_volume_abd`` resamples the volume to (192, 192, 64). The last axis
    is then moved to the front, a channel axis is added, and TorchIO
    ``ZNormalization`` is applied using its mean-threshold mask.

    Args:
        image: Path to a NIfTI file.

    Returns:
        Tensor of shape (1, 64, 192, 192), dtype ``torch.bfloat16``.
    """
    volume = nib.load(image).get_fdata()
    resized = resize_volume_abd(volume)
    channels_first = torch.from_numpy(resized).permute(2, 0, 1).unsqueeze(0)
    normalize = tio.Compose([
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ])
    return normalize(channels_first).type(torch.bfloat16)
| |
|
def read_image(image_path):
    """Load ``image_path`` as a model-ready tensor, choosing the loader by file extension.

    ``.jpg``/``.jpeg``/``.png`` files go through ``load_and_preprocess_image``
    after RGB conversion. ``.nii.gz`` files go through
    ``load_and_preprocess_volume``.

    Raises:
        ValueError: for any other extension (the match is case-sensitive).
    """
    if image_path.endswith(('.jpg', '.jpeg', '.png')):
        rgb = Image.open(image_path).convert('RGB')
        return load_and_preprocess_image(rgb)
    if image_path.endswith('.nii.gz'):
        return load_and_preprocess_volume(image_path)
    raise ValueError("Unsupported file format")
| |
|
def generate(image_path, image, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
    """Run MedVersa on preprocessed inputs and post-process any visual output.

    Args:
        image_path: List of input file paths, used to re-read the originals for display.
        image: Batched image tensor built by ``generate_predictions``.
        context: Free-text clinical context; may be empty.
        modal: Lower-cased modality name (e.g. 'cxr', 'derm', 'ct').
        num_imgs: Number of input images; one ``<img{i}>`` slot is built per image.
        prompt: User prompt.
        num_beams, do_sample, min_length, top_p, repetition_penalty,
        length_penalty, temperature: Forwarded to ``model.llama_model.generate``.

    Returns:
        Tuple ``(output_image, seg_mask_2d, seg_mask_3d, output_text)``. The
        first three are ``None`` unless the generated sequence contains the
        corresponding 2D-segmentation, 3D-segmentation or detection token.
    """
    # Prepend the context only for report-style prompts, or for diagnosis-style
    # prompts in the 'derm' modality.
    has_context = len(context) != 0
    asks_report = 'report' in prompt or 'finding' in prompt or 'impression' in prompt
    asks_derm_dx = modal == 'derm' and ('diagnosis' in prompt or 'issue' in prompt or 'problem' in prompt)
    if has_context and (asks_report or asks_derm_dx):
        prompt = '<context>' + context + '</context>' + prompt
    # CT segmentation prompts are replaced by a canonical one. The replacements
    # contain no other organ keyword, so the first matching organ wins.
    if modal == 'ct' and 'segment' in prompt.lower():
        if 'liver' in prompt:
            prompt = 'Segment the liver.'
        if 'spleen' in prompt:
            prompt = 'Segment the spleen.'
        if 'kidney' in prompt:
            prompt = 'Segment the kidney.'
        if 'pancrea' in prompt:
            prompt = 'Segment the pancreas.'
    img_embeds, atts_img, img_embeds_list = model.encode_img(image.unsqueeze(0), [modal])
    placeholder = ['<ImageHere>'] * 9
    prefix = '###Human:' + ''.join([f'<img{i}>' + ''.join(placeholder) + f'</img{i}>' for i in range(num_imgs)])
    img_embeds, atts_img = model.prompt_wrap(img_embeds, atts_img, [prefix], [num_imgs])
    prompt += '###Assistant:'
    prompt_tokens = model.llama_tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(image.device)
    new_img_embeds, new_atts_img = model.prompt_concat(img_embeds, atts_img, prompt_tokens)

    # Stop sequences are kept as integer token ids. The previous bfloat16 cast
    # rounded them (835 -> 836, 2277 -> 2272, 29937 -> 29952), so neighbouring
    # token ids also compared equal and could end generation early.
    stop_words_ids = [
        torch.tensor([835], device=image.device),
        torch.tensor([2277, 29937], device=image.device),
    ]
    outputs = model.llama_model.generate(
        inputs_embeds=new_img_embeds,
        max_new_tokens=450,
        stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]),
        num_beams=num_beams,
        do_sample=do_sample,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

    hidden_states = outputs.hidden_states
    preds = outputs.sequences[0]
    output_image = None
    seg_mask_2d = None
    seg_mask_3d = None
    # The special tokens in the output decide which visual head runs; later checks override earlier ones.
    if (preds == model.seg_token_idx_2d).any():
        seg_mask = task_seg_2d(model, preds, hidden_states, image)
        output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
    if (preds == model.seg_token_idx_3d).any():
        seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
        output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
    if (preds == model.det_token_idx).any():
        det_box = task_det_2d(model, preds, hidden_states)
        output_image = det_2d_process(image_path, det_box)

    # Drop a leading token id 0 and then a leading token id 1 before decoding.
    if preds[0] == 0:
        preds = preds[1:]
    if preds[0] == 1:
        preds = preds[1:]

    output_text = model.llama_tokenizer.decode(preds, add_special_tokens=False)
    output_text = output_text.split('###')[0].split('Assistant:')[-1].strip()

    if 'mel' in output_text and modal == 'derm':
        output_text = 'The main diagnosis is melanoma.'
    return output_image, seg_mask_2d, seg_mask_3d, output_text
| |
|
def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
    """Load the uploaded files, batch them and run ``generate`` under autocast/no_grad.

    Args:
        images: Sequence of uploaded file paths.
        context, prompt: Text inputs forwarded to ``generate``.
        modality: Modality name; lower-cased before being passed on.
        num_beams, do_sample, min_length, top_p, repetition_penalty,
        length_penalty, temperature: Decoding settings forwarded to ``generate``.

    Returns:
        ``(generated_image, seg_mask_2d, seg_mask_3d, output_text)`` from ``generate``.
    """
    count = len(images)
    modal = modality.lower()
    tensors = [read_image(path).to(device) for path in images]
    # NOTE(review): fixed pause (2 s when modality == 'ct', 1 s otherwise);
    # its purpose is not evident from this file.
    time.sleep(2 if modality == 'ct' else 1)
    batch = torch.cat(tensors)

    with torch.autocast(device), torch.no_grad():
        result = generate(images, batch, context, modal, count, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)

    generated_image, seg_mask_2d, seg_mask_3d, output_text = result
    return generated_image, seg_mask_2d, seg_mask_3d, output_text
| |
|
# Latest outputs shown in the AnnotatedImage: 'image' is a list of frames and
# 'mask' is a parallel list of masks or None. `render` reads it when the slider moves.
my_dict = {}


def _blank_snapshot():
    """Return an (image, annotations) pair holding a black 224x224 placeholder."""
    return (Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)), [])


def gradio_interface(chatbot, images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
    """Gradio callback for the "Generate Predictions" button.

    Validates the inputs and runs the model. It then updates the shared
    ``my_dict`` display state and the chat history. The history is reset when
    the uploaded files differ from the previous request.

    Returns:
        ``(chatbot_history, snapshot, slider_update)`` where ``snapshot`` is an
        ``(image, [(mask, label), ...])`` pair for ``gr.AnnotatedImage``.
    """
    global global_images
    if not images:
        global_images = 'none'
        return [(prompt, "At least one image is required to proceed.")], _blank_snapshot(), gr.update(maximum=0)
    if not prompt or not modality:
        global_images = 'none'
        return [(prompt, "Please provide prompt and modality to proceed.")], _blank_snapshot(), gr.update(maximum=0)

    generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
    # A 3D mask takes precedence over a 2D one.
    mask = seg_mask_3d if seg_mask_3d is not None else seg_mask_2d

    if generated_images is not None:
        output_images = [np.asarray(generated_image).astype(np.uint8) for generated_image in generated_images]
        annotations = [] if mask is None else [(mask[0], "Mask")]
    else:
        # Nothing was drawn, so show the inputs. They are only opened when actually needed.
        output_images = [
            np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8)
            if img.name.endswith(('.jpg', '.jpeg', '.png'))
            else f"{img.name} (3D Volume)"
            for img in images
        ]
        annotations = []
    snapshot = (output_images[0], annotations)

    my_dict['image'] = output_images
    my_dict['mask'] = mask

    # A different set of files starts a fresh conversation. Tolerate a missing history.
    chatbot = list(chatbot or [])
    if global_images is not None and global_images != images:
        chatbot = []
    chatbot.append((prompt, output_text))
    global_images = images

    return chatbot, snapshot, gr.update(maximum=len(output_images) - 1)
| |
|
def render(x):
    """Return the ``(image, annotations)`` pair for frame ``x`` of the current outputs.

    ``x`` is truncated to an int and clamped to the valid frame range. If no
    prediction has been made yet (``my_dict`` has no images), ``None`` is
    returned. Previously this raised KeyError when the slider moved before the
    first generation.
    """
    frames = my_dict.get('image')
    if not frames:
        return None
    index = min(max(int(x), 0), len(frames) - 1)
    image = frames[index]
    masks = my_dict.get('mask')
    if masks is None:
        return (image, [])
    return (image, [(masks[index], "Mask")])
| |
|
def update_context_visibility(task):
    """Return a Gradio update showing the context box only for tasks that use it.

    Visible for "report generation" and "classification", hidden otherwise.
    NOTE(review): this helper is not attached to any event in this file.
    """
    shows_context = task in ("report generation", "classification")
    return gr.update(visible=shows_context)
| |
|
def reset_chatbot():
    """Return a new, empty chat history (a distinct list on every call)."""
    return list()
| |
|
# Build the Gradio UI. Component creation order and `with` nesting define the
# on-screen layout, so statements here are order-sensitive.
with gr.Blocks(theme=gr.themes.Soft()) as demo:

    gr.Markdown("# MedVersa")
    with gr.Row():
        # Left column: inputs, prompt and decoding settings.
        with gr.Column():
            # NOTE(review): the examples below include .nii.gz volumes; confirm the
            # "image"/"numpy" file_types filter lets users pick such files.
            image_input = gr.File(label="Upload Images", file_count="multiple", file_types=["image", "numpy"])

            context_input = gr.Textbox(label="Context", placeholder="Enter context here...", lines=3, visible=True)
            modality_input = gr.Dropdown(choices=["cxr", "derm", "ct"], label="Modality")
            prompt_input = gr.Textbox(label="Prompt", placeholder="Enter prompt here... (images should be referred as <img0>, <img1>, ...)", lines=3)
            submit_button = gr.Button("Generate Predictions")
            # Decoding parameters forwarded through gradio_interface to generate().
            with gr.Accordion("Advanced Settings", open=False):
                num_beams = gr.Slider(label="Number of Beams", minimum=1, maximum=10, step=1, value=1)
                do_sample = gr.Checkbox(label="Do Sample", value=True)
                min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, step=1, value=1)
                top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.1, value=0.9)
                repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
                length_penalty = gr.Slider(label="Length Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.1)

        # Right column: chat transcript and the image viewer.
        with gr.Column():

            chatbot = gr.Chatbot(label="Chatbox")
            # Frame selector. gradio_interface resets its maximum to the number of
            # output frames minus one; render() uses its value as the frame index.
            slider = gr.Slider(minimum=0, maximum=64, value=1, step=1)
            output_image = gr.AnnotatedImage(height=448, label="Images")

    # The input order must match gradio_interface's parameter order.
    # NOTE(review): update_context_visibility and reset_chatbot are defined above
    # but not attached to any event here.
    submit_button.click(
        fn=gradio_interface,
        inputs=[chatbot, image_input, context_input, prompt_input, modality_input, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature],
        outputs=[chatbot, output_image, slider]
    )

    # Moving the slider redraws the selected frame (and its mask, if any).
    slider.change(
        render,
        inputs=[slider],
        outputs=[output_image],
    )

    # Each example fills [files, context, prompt, modality], matching the
    # `inputs` list passed to gr.Examples below.
    examples = [
        [
            ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
            "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
            "How would you characterize the findings from <img0>?",
            "cxr",
        ],
        [
            ["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
            "Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
            "How would you characterize the findings from <img0>?",
            "cxr",
        ],
        [
            ["./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png", "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png"],
            "Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
            "How would you characterize the findings from <img0><img1>?",
            "cxr",
        ],
        [
            ["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
            "Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
            "How would you characterize the findings from <img0>?",
            "cxr",
        ],
        [
            ["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
            "Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
            "How would you characterize the findings from <img0>?",
            "cxr",
        ],
        [
            ["./demo_ex/ISIC_0032258.jpg"],
            "Age:70.\nGender:female.\nLocation:back.",
            "What is primary diagnosis?",
            "derm",
        ],
        [
            ["./demo_ex/Case_01013_0000.nii.gz"],
            "",
            "Segment the liver.",
            "ct",
        ],
        [
            ["./demo_ex/Case_00840_0000.nii.gz"],
            "",
            "Segment the liver.",
            "ct",
        ],
    ]

    gr.Examples(examples, inputs=[image_input, context_input, prompt_input, modality_input])

# share=True asks Gradio to create a public share link in addition to the local server.
demo.launch(share=True)