Spaces: Build error
Update app.py
app.py (CHANGED)
@@ -1,64 +1,512 @@
import gradio as gr
import argparse
import torch
from torch import cuda
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision import transforms
from PIL import Image
import skimage.morphology, skimage.io
import cv2
import numpy as np
import random
from transformers import StoppingCriteria, StoppingCriteriaList
from copy import deepcopy
from medomni.common.config import Config
from medomni.common.dist_utils import get_rank
from medomni.common.registry import registry
import torchio as tio
import nibabel as nib
from scipy import ndimage, misc
import time
import ipdb

# Function to parse command line arguments
def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair in xxx=yyy format will be merged into config file (deprecate), change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args

device = 'cuda' if cuda.is_available() else 'cpu'
# Launch model
args = parse_args()
cfg = Config(args)

model_config = cfg.model_cfg
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
global global_images
global_images = None

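# The next three helpers turn raw model outputs (a 2D segmentation mask, a 3D
# segmentation volume, or a detection box) into display-ready PIL images plus
# mask overlays for the Gradio AnnotatedImage component.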
def seg_2d_process(image_path, pred_mask, img_size=224):
    image = cv2.imread(image_path[0])
    if pred_mask.sum() != 0:
        # Keep only the largest connected component of the predicted mask.
        labels = skimage.morphology.label(pred_mask)
        labelCount = np.bincount(labels.ravel())
        largest_label = np.argmax(labelCount[1:]) + 1
        pred_mask[labels != largest_label] = 0
        pred_mask[labels == largest_label] = 255
        pred_mask = pred_mask.astype(np.uint8)
        contours, _ = cv2.findContours(pred_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
        if contours:
            contours = np.vstack(contours)
            binary_array = np.zeros((img_size, img_size))
            binary_array = cv2.drawContours(binary_array, contours, -1, 255, thickness=cv2.FILLED)
            binary_array = cv2.resize(binary_array, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST) / 255
            image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
            mask = [binary_array]
        else:
            # Build the empty mask before converting `image` to a PIL list.
            mask = [np.zeros((image.shape[1], image.shape[0]))]
            image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
    else:
        mask = [np.zeros((image.shape[1], image.shape[0]))]
        image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
    # output_image = cv2.drawContours(binary_array, contours, -1, (110, 0, 255), 2)
    # output_image_pil = Image.fromarray(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
    return image, mask

def seg_3d_process(image_path, seg_mask):
    img = nib.load(image_path[0]).get_fdata()
    image = window_scan(img).transpose(2, 0, 1).astype(np.uint8)
    if seg_mask.sum() != 0:
        seg_mask = resize_back_volume_abd(seg_mask, image.shape).astype(np.uint8)
        image_slices = []
        contour_slices = []
        for i in range(seg_mask.shape[0]):
            slice_img = np.fliplr(np.rot90(image[i]))
            slice_mask = np.fliplr(np.rot90(seg_mask[i]))
            contours, _ = cv2.findContours(slice_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
            image_slices.append(Image.fromarray(slice_img))
            if contours:
                binary_array = np.zeros(seg_mask.shape[1:])
                binary_array = cv2.drawContours(binary_array, contours, -1, 255, thickness=cv2.FILLED) / 255
                binary_array = cv2.resize(binary_array, slice_img.shape, interpolation=cv2.INTER_NEAREST)
                contour_slices.append(binary_array)
            else:
                contour_slices.append(np.zeros_like(slice_img))
    else:
        # Empty prediction: return every slice with an all-zero mask.
        image_slices = []
        contour_slices = []
        for i in range(image.shape[0]):
            slice_img = np.fliplr(np.rot90(image[i]))
            image_slices.append(Image.fromarray(slice_img))
            contour_slices.append(np.zeros_like(slice_img))

    return image_slices, contour_slices

def det_2d_process(image_path, box):
    image_slices = []
    image = cv2.imread(image_path[0])
    if box is not None:
        hi, wd, _ = image.shape
        color = tuple(np.random.random(size=3) * 256)
        x1, y1, x2, y2 = int(box[0] * wd), int(box[1] * hi), int(box[2] * wd), int(box[3] * hi)
        image = cv2.rectangle(image, (x1, y1), (x2, y2), color, 10)
    image_slices.append(Image.fromarray(image))
    return image_slices

def window_scan(scan, window_center=50, window_width=400):
    """
    Apply windowing to a scan.

    Parameters:
    scan (numpy.ndarray): 3D numpy array of the CT scan
    window_center (int): The center of the window
    window_width (int): The width of the window

    Returns:
    numpy.ndarray: Windowed CT scan
    """
    lower_bound = window_center - (window_width // 2)
    upper_bound = window_center + (window_width // 2)

    windowed_scan = np.clip(scan, lower_bound, upper_bound)
    windowed_scan = (windowed_scan - lower_bound) / (upper_bound - lower_bound)
    windowed_scan = (windowed_scan * 255).astype(np.uint8)

    return windowed_scan

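# With the defaults (center 50 HU, width 400 HU) the window spans [-150, 250]:
# voxels at -150 HU or below map to 0, voxels at 250 HU or above map to 255,
# and soft tissue near 50 HU lands around the middle of the 8-bit range (~127).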
+
def task_seg_2d(model, preds, hidden_states, image):
|
| 135 |
+
token_mask = preds == model.seg_token_idx_2d
|
| 136 |
+
indices = torch.where(token_mask == True)[0].cpu().numpy()
|
| 137 |
+
feats = model.model_seg_2d.encoder(image.unsqueeze(0)[:, 0])
|
| 138 |
+
last_feats = feats[-1]
|
| 139 |
+
target_states = [hidden_states[ind][-1] for ind in indices]
|
| 140 |
+
if target_states:
|
| 141 |
+
target_states = torch.cat(target_states).squeeze()
|
| 142 |
+
seg_states = model.text2seg_2d(target_states).unsqueeze(0)
|
| 143 |
+
last_feats = last_feats + seg_states.unsqueeze(-1).unsqueeze(-1)
|
| 144 |
+
last_feats = model.text2seg_2d_gn(last_feats)
|
| 145 |
+
feats[-1] = last_feats
|
| 146 |
+
seg_feats = model.model_seg_2d.decoder(*feats)
|
| 147 |
+
seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
|
| 148 |
+
seg_probs = F.sigmoid(seg_preds)
|
| 149 |
+
seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
|
| 150 |
+
return seg_mask
|
| 151 |
+
else:
|
| 152 |
+
return None
|
| 153 |
+
|
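# task_seg_3d mirrors the 2D path above: hidden states at the positions where
# the model emitted its 3D segmentation token are projected (text2seg_3d) and
# added to the deepest 3D image features before the volumetric decoder runs.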
def task_seg_3d(model, preds, hidden_states, img_embeds_list):
    new_img_embeds_list = deepcopy(img_embeds_list)
    token_mask = preds == model.seg_token_idx_3d
    indices = torch.where(token_mask == True)[0].cpu().numpy()
    target_states = [hidden_states[ind][-1] for ind in indices]
    if target_states:
        target_states = torch.cat(target_states).squeeze().unsqueeze(0)
        seg_states = model.text2seg_3d(target_states)
        last_feats = new_img_embeds_list[-1]
        last_feats = last_feats + seg_states.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        last_feats = model.text2seg_3d_gn(last_feats)
        new_img_embeds_list[-1] = last_feats
        seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
        seg_probs = F.sigmoid(seg_preds)
        seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
        return seg_mask

def task_det_2d(model, preds, hidden_states):
    token_mask = preds == model.det_token_idx
    indices = torch.where(token_mask == True)[0].cpu().numpy()
    target_states = [hidden_states[ind][-1] for ind in indices]
    if target_states:
        target_states = torch.cat(target_states).squeeze()
        det_states = model.text_det(target_states).detach().cpu()
        return det_states.to(torch.float32).numpy()
    return torch.zeros_like(indices)

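# Generation is cut off once the output ends with one of the stop sequences
# passed to model.llama_model.generate below (token ids 835 and [2277, 29937],
# corresponding to the '###' turn delimiter of the ###Human:/###Assistant: format).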
class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=[]):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False

def resize_back_volume_abd(img, target_size):
    desired_depth = target_size[0]
    desired_width = target_size[1]
    desired_height = target_size[2]

    current_depth = img.shape[0]  # [d, w, h]
    current_width = img.shape[1]
    current_height = img.shape[2]

    depth = current_depth / desired_depth
    width = current_width / desired_width
    height = current_height / desired_height

    depth_factor = 1 / depth
    width_factor = 1 / width
    height_factor = 1 / height

    img = ndimage.zoom(img, (depth_factor, width_factor, height_factor), order=0)
    return img

def resize_volume_abd(img):
    img[img <= -200] = -200
    img[img >= 300] = 300

    desired_depth = 64
    desired_width = 192
    desired_height = 192

    current_width = img.shape[0]  # [w, h, d]
    current_height = img.shape[1]
    current_depth = img.shape[2]

    depth = current_depth / desired_depth
    width = current_width / desired_width
    height = current_height / desired_height

    depth_factor = 1 / depth
    width_factor = 1 / width
    height_factor = 1 / height

    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=0)
    return img

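# Preprocessing: 2D images are resized to 224x224 and normalized with the
# mean/std defined in load_and_preprocess_image (the CLIP visual-encoder
# constants); 3D volumes are clipped to [-200, 300] HU, resampled to
# 64x192x192 via resize_volume_abd, and z-normalized with torchio.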
+
def load_and_preprocess_image(image):
|
| 236 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
| 237 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
| 238 |
+
transform = transforms.Compose([
|
| 239 |
+
transforms.Resize([224, 224]),
|
| 240 |
+
transforms.ToTensor(),
|
| 241 |
+
transforms.Normalize(mean, std)
|
| 242 |
+
])
|
| 243 |
+
image = transform(image).type(torch.bfloat16).unsqueeze(0)
|
| 244 |
+
return image
|
| 245 |
+
|
| 246 |
+
def load_and_preprocess_volume(image):
|
| 247 |
+
img = nib.load(image).get_fdata()
|
| 248 |
+
image = torch.from_numpy(resize_volume_abd(img)).permute(2,0,1)
|
| 249 |
+
transform = tio.Compose([
|
| 250 |
+
tio.ZNormalization(masking_method=tio.ZNormalization.mean),
|
| 251 |
+
])
|
| 252 |
+
image = transform(image.unsqueeze(0)).type(torch.bfloat16)
|
| 253 |
+
return image
|
| 254 |
+
|
| 255 |
+
def read_image(image_path):
|
| 256 |
+
if image_path.endswith(('.jpg', '.jpeg', '.png')):
|
| 257 |
+
return load_and_preprocess_image(Image.open(image_path).convert('RGB'))
|
| 258 |
+
elif image_path.endswith('.nii.gz'):
|
| 259 |
+
return load_and_preprocess_volume(image_path)
|
| 260 |
+
else:
|
| 261 |
+
raise ValueError("Unsupported file format")
|
| 262 |
+
|
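# Prompt layout used by generate(): each image i is wrapped as <img{i}>...</img{i}>
# containing nine <ImageHere> placeholder tokens, turns are delimited by
# ###Human: / ###Assistant:, and optional clinical context is prepended inside
# <context>...</context> when the prompt asks for a report or a diagnosis.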
def generate(image_path, image, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
    if (len(context) != 0 and ('report' in prompt or 'finding' in prompt or 'impression' in prompt)) or (len(context) != 0 and modal == 'derm' and ('diagnosis' in prompt or 'issue' in prompt or 'problem' in prompt)):
        prompt = '<context>' + context + '</context>' + prompt
    if modal == 'ct' and 'segment' in prompt.lower():
        if 'liver' in prompt:
            prompt = 'Segment the liver.'
        if 'spleen' in prompt:
            prompt = 'Segment the spleen.'
        if 'kidney' in prompt:
            prompt = 'Segment the kidney.'
        if 'pancrea' in prompt:
            prompt = 'Segment the pancreas.'
    img_embeds, atts_img, img_embeds_list = model.encode_img(image.unsqueeze(0), [modal])
    placeholder = ['<ImageHere>'] * 9
    prefix = '###Human:' + ''.join([f'<img{i}>' + ''.join(placeholder) + f'</img{i}>' for i in range(num_imgs)])
    img_embeds, atts_img = model.prompt_wrap(img_embeds, atts_img, [prefix], [num_imgs])
    prompt += '###Assistant:'
    prompt_tokens = model.llama_tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(image.device)
    new_img_embeds, new_atts_img = model.prompt_concat(img_embeds, atts_img, prompt_tokens)

    outputs = model.llama_model.generate(
        inputs_embeds=new_img_embeds,
        max_new_tokens=450,
        stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub(stops=[
            torch.tensor([835]).type(torch.bfloat16).to(image.device),
            torch.tensor([2277, 29937]).type(torch.bfloat16).to(image.device)
        ])]),
        num_beams=num_beams,
        do_sample=do_sample,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

    hidden_states = outputs.hidden_states
    preds = outputs.sequences[0]
    output_image = None
    seg_mask_2d = None
    seg_mask_3d = None
    if sum(preds == model.seg_token_idx_2d):
        seg_mask = task_seg_2d(model, preds, hidden_states, image)
        output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
    if sum(preds == model.seg_token_idx_3d):
        seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
        output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
    if sum(preds == model.det_token_idx):
        det_box = task_det_2d(model, preds, hidden_states)
        output_image = det_2d_process(image_path, det_box)

    if preds[0] == 0:  # Remove unknown token <unk> at the beginning
        preds = preds[1:]
    if preds[0] == 1:  # Remove start token <s> at the beginning
        preds = preds[1:]

    output_text = model.llama_tokenizer.decode(preds, add_special_tokens=False)
    output_text = output_text.split('###')[0].split('Assistant:')[-1].strip()

    if 'mel' in output_text and modal == 'derm':
        output_text = 'The main diagnosis is melanoma.'
    return output_image, seg_mask_2d, seg_mask_3d, output_text

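# generate_predictions: preprocess every uploaded file, stack the tensors into
# one batch, and run generate() under autocast/no_grad on the selected device.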
def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
    num_imgs = len(images)
    modal = modality.lower()
    image_tensors = [read_image(img).to(device) for img in images]
    if modality == 'ct':
        time.sleep(2)
    else:
        time.sleep(1)
    image_tensor = torch.cat(image_tensors)

    with torch.autocast(device):
        with torch.no_grad():
            generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(images, image_tensor, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)

    return generated_image, seg_mask_2d, seg_mask_3d, output_text

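# Hypothetical programmatic call (path and context taken from the bundled
# examples, sampling settings matching the UI defaults below); the Gradio UI
# passes the same arguments through gradio_interface:
#   generated_image, seg_mask_2d, seg_mask_3d, text = generate_predictions(
#       ["./demo_ex/ISIC_0032258.jpg"],
#       "Age:70.\nGender:female.\nLocation:back.",
#       "What is primary diagnosis?", "derm",
#       num_beams=1, do_sample=True, min_length=1, top_p=0.9,
#       repetition_penalty=1.0, length_penalty=1.0, temperature=0.1)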
my_dict = {}
def gradio_interface(chatbot, images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
    global global_images
    if not images:
        image = np.zeros((224, 224, 3), dtype=np.uint8)
        blank_image = Image.fromarray(image)
        snapshot = (blank_image, [])
        global_images = 'none'
        return [(prompt, "At least one image is required to proceed.")], snapshot, gr.update(maximum=0)
    if not prompt or not modality:
        image = np.zeros((224, 224, 3), dtype=np.uint8)
        blank_image = Image.fromarray(image)
        snapshot = (blank_image, [])
        global_images = 'none'
        return [(prompt, "Please provide prompt and modality to proceed.")], snapshot, gr.update(maximum=0)

    generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
    output_images = []
    input_images = [np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8) if img.name.endswith(('.jpg', '.jpeg', '.png')) else f"{img.name} (3D Volume)" for img in images]
    if generated_images is not None:
        for generated_image in generated_images:
            output_images.append(np.asarray(generated_image).astype(np.uint8))
        snapshot = (output_images[0], [])
        if seg_mask_2d is not None:
            snapshot = (output_images[0], [(seg_mask_2d[0], "Mask")])
        if seg_mask_3d is not None:
            snapshot = (output_images[0], [(seg_mask_3d[0], "Mask")])
    else:
        output_images = input_images.copy()
        snapshot = (output_images[0], [])

    my_dict['image'] = output_images
    my_dict['mask'] = None
    if seg_mask_2d is not None:
        my_dict['mask'] = seg_mask_2d
    if seg_mask_3d is not None:
        my_dict['mask'] = seg_mask_3d

    if global_images != images and (global_images is not None):
        chatbot = []
        chatbot.append((prompt, output_text))
    else:
        chatbot.append((prompt, output_text))
    global_images = images

    return chatbot, snapshot, gr.update(maximum=len(output_images)-1)

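# my_dict caches the images and masks from the most recent run; render(x)
# returns slice x (clamped to the valid range) as an AnnotatedImage tuple when
# the slice slider moves.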
def render(x):
    if x > len(my_dict['image'])-1:
        x = len(my_dict['image'])-1
    if x < 0:
        x = 0
    image = my_dict['image'][x]
    if my_dict['mask'] is None:
        return (image, [])
    else:
        mask = my_dict['mask'][x]
        value = (image, [(mask, "Mask")])
        return value

def update_context_visibility(task):
    if task == "report generation" or task == 'classification':
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)

def reset_chatbot():
    return []

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # with gr.Row():
    #     gr.Markdown("<link href='https://fonts.googleapis.com/css2?family=Libre+Franklin:wght@400;700&display=swap' rel='stylesheet'>")
    gr.Markdown("# MedVersa")
    with gr.Row():
        with gr.Column():
            image_input = gr.File(label="Upload Images", file_count="multiple", file_types=["image", "numpy"])
            # task_input = gr.Dropdown(choices=["report generation", "vqa", "localization", "classification"], label="Task")
            context_input = gr.Textbox(label="Context", placeholder="Enter context here...", lines=3, visible=True)
            modality_input = gr.Dropdown(choices=["cxr", "derm", "ct"], label="Modality")
            prompt_input = gr.Textbox(label="Prompt", placeholder="Enter prompt here... (images should be referred as <img0>, <img1>, ...)", lines=3)
            submit_button = gr.Button("Generate Predictions")
            with gr.Accordion("Advanced Settings", open=False):
                num_beams = gr.Slider(label="Number of Beams", minimum=1, maximum=10, step=1, value=1)
                do_sample = gr.Checkbox(label="Do Sample", value=True)
                min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, step=1, value=1)
                top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.1, value=0.9)
                repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
                length_penalty = gr.Slider(label="Length Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.1)

        with gr.Column():
            # output_text = gr.Textbox(label="Generated Text", lines=10, elem_classes="output-textbox")
            chatbot = gr.Chatbot(label="Chatbox")
            slider = gr.Slider(minimum=0, maximum=64, value=1, step=1)
            output_image = gr.AnnotatedImage(height=448, label="Images")

    # task_input.change(
    #     fn=update_context_visibility,
    #     inputs=task_input,
    #     outputs=context_input
    # )

    submit_button.click(
        fn=gradio_interface,
        inputs=[chatbot, image_input, context_input, prompt_input, modality_input, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature],
        outputs=[chatbot, output_image, slider]
    )

    slider.change(
        render,
        inputs=[slider],
        outputs=[output_image],
    )

examples = [
|
| 459 |
+
[
|
| 460 |
+
["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
|
| 461 |
+
"Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
|
| 462 |
+
"How would you characterize the findings from <img0>?",
|
| 463 |
+
"cxr",
|
| 464 |
+
],
|
| 465 |
+
[
|
| 466 |
+
["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
|
| 467 |
+
"Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
|
| 468 |
+
"How would you characterize the findings from <img0>?",
|
| 469 |
+
"cxr",
|
| 470 |
+
],
|
| 471 |
+
[
|
| 472 |
+
["./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png", "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png"],
|
| 473 |
+
"Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
|
| 474 |
+
"How would you characterize the findings from <img0><img1>?",
|
| 475 |
+
"cxr",
|
| 476 |
+
],
|
| 477 |
+
[
|
| 478 |
+
["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
|
| 479 |
+
"Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
|
| 480 |
+
"How would you characterize the findings from <img0>?",
|
| 481 |
+
"cxr",
|
| 482 |
+
],
|
| 483 |
+
[
|
| 484 |
+
["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
|
| 485 |
+
"Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
|
| 486 |
+
"How would you characterize the findings from <img0>?",
|
| 487 |
+
"cxr",
|
| 488 |
+
],
|
| 489 |
+
[
|
| 490 |
+
["./demo_ex/ISIC_0032258.jpg"],
|
| 491 |
+
"Age:70.\nGender:female.\nLocation:back.",
|
| 492 |
+
"What is primary diagnosis?",
|
| 493 |
+
"derm",
|
| 494 |
+
],
|
| 495 |
+
[
|
| 496 |
+
["./demo_ex/Case_01013_0000.nii.gz"],
|
| 497 |
+
"",
|
| 498 |
+
"Segment the liver.",
|
| 499 |
+
"ct",
|
| 500 |
+
],
|
| 501 |
+
[
|
| 502 |
+
["./demo_ex/Case_00840_0000.nii.gz"],
|
| 503 |
+
"",
|
| 504 |
+
"Segment the liver.",
|
| 505 |
+
"ct",
|
| 506 |
+
],
|
| 507 |
+
]
|
| 508 |
+
|
| 509 |
+
gr.Examples(examples, inputs=[image_input, context_input, prompt_input, modality_input])
|
| 510 |
+
|
| 511 |
+
# Run Gradio app
|
| 512 |
+
demo.launch(share=True, show_error=True)
|