Abo-Doonya committed on
Commit 33fe547 · verified · 1 Parent(s): 8f085b2

Update app.py

Files changed (1)
  1. app.py +510 -62
app.py CHANGED
@@ -1,64 +1,512 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
- 
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
- 
- 
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
- 
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
- 
-     messages.append({"role": "user", "content": message})
- 
-     response = ""
- 
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
          top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
- 
-         response += token
-         yield response
- 
- 
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
- 
- 
- if __name__ == "__main__":
-     demo.launch()
  import gradio as gr
+ import argparse
+ import torch
+ from torch import cuda
+ import torch.nn.functional as F
+ import torchvision.transforms.functional as TF
+ from torchvision import transforms
+ from PIL import Image
+ import skimage.morphology, skimage.io
+ import cv2
+ import numpy as np
+ import random
+ from transformers import StoppingCriteria, StoppingCriteriaList
+ from copy import deepcopy
+ from medomni.common.config import Config
+ from medomni.common.dist_utils import get_rank
+ from medomni.common.registry import registry
+ import torchio as tio
+ import nibabel as nib
+ from scipy import ndimage, misc
+ import time
+ import ipdb
+ 
+ # Parse command line arguments
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Demo")
+     parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+     parser.add_argument(
+         "--options",
+         nargs="+",
+         help="override some settings in the used config; key-value pairs in xxx=yyy format are merged into the config (deprecated; use --cfg-options instead).",
+     )
+     args = parser.parse_args()
+     return args
+ 
+ device = 'cuda' if cuda.is_available() else 'cpu'
+ # Launch model
+ args = parse_args()
+ cfg = Config(args)
+ 
+ model_config = cfg.model_cfg
+ model_cls = registry.get_model_class(model_config.arch)
+ model = model_cls.from_pretrained('hyzhou/MedVersa').to(device).eval()
+ global global_images
+ global_images = None
+ 
+ def seg_2d_process(image_path, pred_mask, img_size=224):
+     image = cv2.imread(image_path[0])
+     if pred_mask.sum() != 0:
+         # Keep only the largest connected component of the predicted mask.
+         labels = skimage.morphology.label(pred_mask)
+         labelCount = np.bincount(labels.ravel())
+         largest_label = np.argmax(labelCount[1:]) + 1
+         pred_mask[labels != largest_label] = 0
+         pred_mask[labels == largest_label] = 255
+         pred_mask = pred_mask.astype(np.uint8)
+         contours, _ = cv2.findContours(pred_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+         if contours:
+             contours = np.vstack(contours)
+             binary_array = np.zeros((img_size, img_size))
+             binary_array = cv2.drawContours(binary_array, contours, -1, 255, thickness=cv2.FILLED)
+             binary_array = cv2.resize(binary_array, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST) / 255
+             image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
+             mask = [binary_array]
+         else:
+             # Build the empty mask before image is replaced by the PIL list.
+             mask = [np.zeros((image.shape[1], image.shape[0]))]
+             image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
+     else:
+         mask = [np.zeros((image.shape[1], image.shape[0]))]
+         image = [Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))]
+     # output_image = cv2.drawContours(binary_array, contours, -1, (110, 0, 255), 2)
+     # output_image_pil = Image.fromarray(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
+     return image, mask
+ 
+ def seg_3d_process(image_path, seg_mask):
+     img = nib.load(image_path[0]).get_fdata()
+     image = window_scan(img).transpose(2,0,1).astype(np.uint8)
+     if seg_mask.sum() != 0:
+         seg_mask = resize_back_volume_abd(seg_mask, image.shape).astype(np.uint8)
+         image_slices = []
+         contour_slices = []
+         for i in range(seg_mask.shape[0]):
+             slice_img = np.fliplr(np.rot90(image[i]))
+             slice_mask = np.fliplr(np.rot90(seg_mask[i]))
+             contours, _ = cv2.findContours(slice_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+             image_slices.append(Image.fromarray(slice_img))
+             if contours:
+                 binary_array = np.zeros(seg_mask.shape[1:])
+                 binary_array = cv2.drawContours(binary_array, contours, -1, 255, thickness=cv2.FILLED) / 255
+                 binary_array = cv2.resize(binary_array, slice_img.shape, interpolation=cv2.INTER_NEAREST)
+                 contour_slices.append(binary_array)
+             else:
+                 contour_slices.append(np.zeros_like(slice_img))
+     else:
+         # No segmentation predicted: return every slice with an empty mask.
+         image_slices = []
+         contour_slices = []
+         for i in range(image.shape[0]):
+             slice_img = np.fliplr(np.rot90(image[i]))
+             image_slices.append(Image.fromarray(slice_img))
+             contour_slices.append(np.zeros_like(slice_img))
+ 
+     return image_slices, contour_slices
+ 
+ def det_2d_process(image_path, box):
+     image_slices = []
+     image = cv2.imread(image_path[0])
+     if box is not None:
+         hi, wd, _ = image.shape
+         color = tuple(np.random.random(size=3) * 256)
+         x1, y1, x2, y2 = int(box[0]*wd), int(box[1]*hi), int(box[2]*wd), int(box[3]*hi)
+         image = cv2.rectangle(image, (x1, y1), (x2, y2), color, 10)
+     image_slices.append(Image.fromarray(image))
+     return image_slices
+ 
+ def window_scan(scan, window_center=50, window_width=400):
+     """
+     Apply windowing to a scan.
+ 
+     Parameters:
+     scan (numpy.ndarray): 3D numpy array of the CT scan
+     window_center (int): The center of the window
+     window_width (int): The width of the window
+ 
+     Returns:
+     numpy.ndarray: Windowed CT scan
+     """
+     lower_bound = window_center - (window_width // 2)
+     upper_bound = window_center + (window_width // 2)
+ 
+     windowed_scan = np.clip(scan, lower_bound, upper_bound)
+     windowed_scan = (windowed_scan - lower_bound) / (upper_bound - lower_bound)
+     windowed_scan = (windowed_scan * 255).astype(np.uint8)
+ 
+     return windowed_scan
+ 
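For reference, the defaults above (window_center=50, window_width=400) clip Hounsfield units to [-150, 250] and rescale them to the 0 to 255 range; a quick sanity check with a few illustrative HU values (the sample array is for illustration only, not part of the commit):

    hu = np.array([-1000, 0, 50, 250, 1000])  # air, water, soft tissue, upper bound, dense bone
    window_scan(hu)                           # -> array([  0,  95, 127, 255, 255], dtype=uint8)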
+ def task_seg_2d(model, preds, hidden_states, image):
+     token_mask = preds == model.seg_token_idx_2d
+     indices = torch.where(token_mask == True)[0].cpu().numpy()
+     feats = model.model_seg_2d.encoder(image.unsqueeze(0)[:, 0])
+     last_feats = feats[-1]
+     target_states = [hidden_states[ind][-1] for ind in indices]
+     if target_states:
+         target_states = torch.cat(target_states).squeeze()
+         seg_states = model.text2seg_2d(target_states).unsqueeze(0)
+         last_feats = last_feats + seg_states.unsqueeze(-1).unsqueeze(-1)
+         last_feats = model.text2seg_2d_gn(last_feats)
+         feats[-1] = last_feats
+         seg_feats = model.model_seg_2d.decoder(*feats)
+         seg_preds = model.model_seg_2d.segmentation_head(seg_feats)
+         seg_probs = F.sigmoid(seg_preds)
+         seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
+         return seg_mask
+     else:
+         return None
+ 
+ def task_seg_3d(model, preds, hidden_states, img_embeds_list):
+     new_img_embeds_list = deepcopy(img_embeds_list)
+     token_mask = preds == model.seg_token_idx_3d
+     indices = torch.where(token_mask == True)[0].cpu().numpy()
+     target_states = [hidden_states[ind][-1] for ind in indices]
+     if target_states:
+         target_states = torch.cat(target_states).squeeze().unsqueeze(0)
+         seg_states = model.text2seg_3d(target_states)
+         last_feats = new_img_embeds_list[-1]
+         last_feats = last_feats + seg_states.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+         last_feats = model.text2seg_3d_gn(last_feats)
+         new_img_embeds_list[-1] = last_feats
+         seg_preds = model.visual_encoder_3d(encoder_only=False, x_=new_img_embeds_list)
+         seg_probs = F.sigmoid(seg_preds)
+         seg_mask = seg_probs.to(torch.float32).cpu().squeeze().numpy() >= 0.5
+         return seg_mask
+ 
+ def task_det_2d(model, preds, hidden_states):
+     token_mask = preds == model.det_token_idx
+     indices = torch.where(token_mask == True)[0].cpu().numpy()
+     target_states = [hidden_states[ind][-1] for ind in indices]
+     if target_states:
+         target_states = torch.cat(target_states).squeeze()
+         det_states = model.text_det(target_states).detach().cpu()
+         return det_states.to(torch.float32).numpy()
+     return torch.zeros_like(indices)
+ 
+ class StoppingCriteriaSub(StoppingCriteria):
+     def __init__(self, stops=[]):
+         super().__init__()
+         self.stops = stops
+ 
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+         for stop in self.stops:
+             if torch.all((stop == input_ids[0][-len(stop):])).item():
+                 return True
+         return False
+ 
+ def resize_back_volume_abd(img, target_size):
+     desired_depth = target_size[0]
+     desired_width = target_size[1]
+     desired_height = target_size[2]
+ 
+     current_depth = img.shape[0]  # [d, w, h]
+     current_width = img.shape[1]
+     current_height = img.shape[2]
+ 
+     depth = current_depth / desired_depth
+     width = current_width / desired_width
+     height = current_height / desired_height
+ 
+     depth_factor = 1 / depth
+     width_factor = 1 / width
+     height_factor = 1 / height
+ 
+     img = ndimage.zoom(img, (depth_factor, width_factor, height_factor), order=0)
+     return img
+ 
+ def resize_volume_abd(img):
+     img[img<=-200] = -200
+     img[img>=300] = 300
+ 
+     desired_depth = 64
+     desired_width = 192
+     desired_height = 192
+ 
+     current_width = img.shape[0]  # [w, h, d]
+     current_height = img.shape[1]
+     current_depth = img.shape[2]
+ 
+     depth = current_depth / desired_depth
+     width = current_width / desired_width
+     height = current_height / desired_height
+ 
+     depth_factor = 1 / depth
+     width_factor = 1 / width
+     height_factor = 1 / height
+ 
+     img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=0)
+     return img
+ 
+ def load_and_preprocess_image(image):
+     mean = (0.48145466, 0.4578275, 0.40821073)
+     std = (0.26862954, 0.26130258, 0.27577711)
+     transform = transforms.Compose([
+         transforms.Resize([224, 224]),
+         transforms.ToTensor(),
+         transforms.Normalize(mean, std)
+     ])
+     image = transform(image).type(torch.bfloat16).unsqueeze(0)
+     return image
+ 
+ def load_and_preprocess_volume(image):
+     img = nib.load(image).get_fdata()
+     image = torch.from_numpy(resize_volume_abd(img)).permute(2,0,1)
+     transform = tio.Compose([
+         tio.ZNormalization(masking_method=tio.ZNormalization.mean),
+     ])
+     image = transform(image.unsqueeze(0)).type(torch.bfloat16)
+     return image
+ 
+ def read_image(image_path):
+     if image_path.endswith(('.jpg', '.jpeg', '.png')):
+         return load_and_preprocess_image(Image.open(image_path).convert('RGB'))
+     elif image_path.endswith('.nii.gz'):
+         return load_and_preprocess_volume(image_path)
+     else:
+         raise ValueError("Unsupported file format")
+ 
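Tracing through the two loaders above, the tensors handed to the model have these shapes (file names are taken from the examples further down; the shapes follow from the transforms as written):

    x2d = read_image('./demo_ex/ISIC_0032258.jpg')        # torch.Size([1, 3, 224, 224]), bfloat16
    x3d = read_image('./demo_ex/Case_01013_0000.nii.gz')  # torch.Size([1, 64, 192, 192]), bfloat16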
+ def generate(image_path, image, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
+     if (len(context) != 0 and ('report' in prompt or 'finding' in prompt or 'impression' in prompt)) or (len(context) != 0 and modal=='derm' and ('diagnosis' in prompt or 'issue' in prompt or 'problem' in prompt)):
+         prompt = '<context>' + context + '</context>' + prompt
+     if modal == 'ct' and 'segment' in prompt.lower():
+         if 'liver' in prompt:
+             prompt = 'Segment the liver.'
+         if 'spleen' in prompt:
+             prompt = 'Segment the spleen.'
+         if 'kidney' in prompt:
+             prompt = 'Segment the kidney.'
+         if 'pancrea' in prompt:
+             prompt = 'Segment the pancreas.'
+     img_embeds, atts_img, img_embeds_list = model.encode_img(image.unsqueeze(0), [modal])
+     placeholder = ['<ImageHere>'] * 9
+     prefix = '###Human:' + ''.join([f'<img{i}>' + ''.join(placeholder) + f'</img{i}>' for i in range(num_imgs)])
+     img_embeds, atts_img = model.prompt_wrap(img_embeds, atts_img, [prefix], [num_imgs])
+     prompt += '###Assistant:'
+     prompt_tokens = model.llama_tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(image.device)
+     new_img_embeds, new_atts_img = model.prompt_concat(img_embeds, atts_img, prompt_tokens)
+ 
+     outputs = model.llama_model.generate(
+         inputs_embeds=new_img_embeds,
+         max_new_tokens=450,
+         stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub(stops=[
+             torch.tensor([835]).type(torch.bfloat16).to(image.device),
+             torch.tensor([2277, 29937]).type(torch.bfloat16).to(image.device)
+         ])]),
+         num_beams=num_beams,
+         do_sample=do_sample,
+         min_length=min_length,
          top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         length_penalty=length_penalty,
+         temperature=temperature,
+         output_hidden_states=True,
+         return_dict_in_generate=True,
+     )
+ 
+     hidden_states = outputs.hidden_states
+     preds = outputs.sequences[0]
+     output_image = None
+     seg_mask_2d = None
+     seg_mask_3d = None
+     if sum(preds == model.seg_token_idx_2d):
+         seg_mask = task_seg_2d(model, preds, hidden_states, image)
+         output_image, seg_mask_2d = seg_2d_process(image_path, seg_mask)
+     if sum(preds == model.seg_token_idx_3d):
+         seg_mask = task_seg_3d(model, preds, hidden_states, img_embeds_list)
+         output_image, seg_mask_3d = seg_3d_process(image_path, seg_mask)
+     if sum(preds == model.det_token_idx):
+         det_box = task_det_2d(model, preds, hidden_states)
+         output_image = det_2d_process(image_path, det_box)
+ 
+     if preds[0] == 0:  # Remove unknown token <unk> at the beginning
+         preds = preds[1:]
+     if preds[0] == 1:  # Remove start token <s> at the beginning
+         preds = preds[1:]
+ 
+     output_text = model.llama_tokenizer.decode(preds, add_special_tokens=False)
+     output_text = output_text.split('###')[0].split('Assistant:')[-1].strip()
+ 
+     if 'mel' in output_text and modal == 'derm':
+         output_text = 'The main diagnosis is melanoma.'
+     return output_image, seg_mask_2d, seg_mask_3d, output_text
+ 
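To make the prompt convention concrete, this is roughly the text that generate() assembles for a single-image report request before prompt_wrap/prompt_concat splice the image embeddings into the <ImageHere> slots (a sketch only; the context string is abbreviated from the first example below):

    placeholder = '<ImageHere>' * 9
    prefix = '###Human:' + '<img0>' + placeholder + '</img0>'
    wrapped_prompt = '<context>Age:30-40.\nGender:F. ...</context>' + 'How would you characterize the findings from <img0>?' + '###Assistant:'
    # The model receives prefix (with the nine placeholders replaced by image embeddings)
    # followed by the tokenized wrapped_prompt.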
+ def generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
+     num_imgs = len(images)
+     modal = modality.lower()
+     image_tensors = [read_image(img).to(device) for img in images]
+     if modality == 'ct':
+         time.sleep(2)
+     else:
+         time.sleep(1)
+     image_tensor = torch.cat(image_tensors)
+ 
+     with torch.autocast(device):
+         with torch.no_grad():
+             generated_image, seg_mask_2d, seg_mask_3d, output_text = generate(images, image_tensor, context, modal, num_imgs, prompt, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
+ 
+     return generated_image, seg_mask_2d, seg_mask_3d, output_text
+ 
+ my_dict = {}
+ def gradio_interface(chatbot, images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature):
+     global global_images
+     if not images:
+         image = np.zeros((224, 224, 3), dtype=np.uint8)
+         blank_image = Image.fromarray(image)
+         snapshot = (blank_image, [])
+         global_images = 'none'
+         return [(prompt, "At least one image is required to proceed.")], snapshot, gr.update(maximum=0)
+     if not prompt or not modality:
+         image = np.zeros((224, 224, 3), dtype=np.uint8)
+         blank_image = Image.fromarray(image)
+         snapshot = (blank_image, [])
+         global_images = 'none'
+         return [(prompt, "Please provide a prompt and modality to proceed.")], snapshot, gr.update(maximum=0)
+ 
+     generated_images, seg_mask_2d, seg_mask_3d, output_text = generate_predictions(images, context, prompt, modality, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature)
+     output_images = []
+     input_images = [np.asarray(Image.open(img.name).convert('RGB')).astype(np.uint8) if img.name.endswith(('.jpg', '.jpeg', '.png')) else f"{img.name} (3D Volume)" for img in images]
+     if generated_images is not None:
+         for generated_image in generated_images:
+             output_images.append(np.asarray(generated_image).astype(np.uint8))
+         snapshot = (output_images[0], [])
+         if seg_mask_2d is not None:
+             snapshot = (output_images[0], [(seg_mask_2d[0], "Mask")])
+         if seg_mask_3d is not None:
+             snapshot = (output_images[0], [(seg_mask_3d[0], "Mask")])
+     else:
+         output_images = input_images.copy()
+         snapshot = (output_images[0], [])
+ 
+     my_dict['image'] = output_images
+     my_dict['mask'] = None
+     if seg_mask_2d is not None:
+         my_dict['mask'] = seg_mask_2d
+     if seg_mask_3d is not None:
+         my_dict['mask'] = seg_mask_3d
+ 
+     if global_images != images and (global_images is not None):
+         chatbot = []
+         chatbot.append((prompt, output_text))
+     else:
+         chatbot.append((prompt, output_text))
+     global_images = images
+ 
+     return chatbot, snapshot, gr.update(maximum=len(output_images)-1)
+ 
+ def render(x):
+     if x > len(my_dict['image'])-1:
+         x = len(my_dict['image'])-1
+     if x < 0:
+         x = 0
+     image = my_dict['image'][x]
+     if my_dict['mask'] is None:
+         return (image, [])
+     else:
+         mask = my_dict['mask'][x]
+         value = (image, [(mask, "Mask")])
+         return value
+ 
+ def update_context_visibility(task):
+     if task == "report generation" or task == 'classification':
+         return gr.update(visible=True)
+     else:
+         return gr.update(visible=False)
+ 
+ def reset_chatbot():
+     return []
+ 
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     # with gr.Row():
+     #     gr.Markdown("<link href='https://fonts.googleapis.com/css2?family=Libre+Franklin:wght@400;700&display=swap' rel='stylesheet'>")
+     gr.Markdown("# MedVersa")
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.File(label="Upload Images", file_count="multiple", file_types=["image", "numpy"])
+             # task_input = gr.Dropdown(choices=["report generation", "vqa", "localization", "classification"], label="Task")
+             context_input = gr.Textbox(label="Context", placeholder="Enter context here...", lines=3, visible=True)
+             modality_input = gr.Dropdown(choices=["cxr", "derm", "ct"], label="Modality")
+             prompt_input = gr.Textbox(label="Prompt", placeholder="Enter prompt here... (images should be referred to as <img0>, <img1>, ...)", lines=3)
+             submit_button = gr.Button("Generate Predictions")
+             with gr.Accordion("Advanced Settings", open=False):
+                 num_beams = gr.Slider(label="Number of Beams", minimum=1, maximum=10, step=1, value=1)
+                 do_sample = gr.Checkbox(label="Do Sample", value=True)
+                 min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, step=1, value=1)
+                 top_p = gr.Slider(label="Top P", minimum=0.1, maximum=1.0, step=0.1, value=0.9)
+                 repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
+                 length_penalty = gr.Slider(label="Length Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.0)
+                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.1)
+ 
+         with gr.Column():
+             # output_text = gr.Textbox(label="Generated Text", lines=10, elem_classes="output-textbox")
+             chatbot = gr.Chatbot(label="Chatbox")
+             slider = gr.Slider(minimum=0, maximum=64, value=1, step=1)
+             output_image = gr.AnnotatedImage(height=448, label="Images")
+ 
+     # task_input.change(
+     #     fn=update_context_visibility,
+     #     inputs=task_input,
+     #     outputs=context_input
+     # )
+ 
+     submit_button.click(
+         fn=gradio_interface,
+         inputs=[chatbot, image_input, context_input, prompt_input, modality_input, num_beams, do_sample, min_length, top_p, repetition_penalty, length_penalty, temperature],
+         outputs=[chatbot, output_image, slider]
+     )
+ 
+     slider.change(
+         render,
+         inputs=[slider],
+         outputs=[output_image],
+     )
+ 
+     examples = [
+         [
+             ["./demo_ex/c536f749-2326f755-6a65f28f-469affd2-26392ce9.png"],
+             "Age:30-40.\nGender:F.\nIndication: ___-year-old female with end-stage renal disease not on dialysis presents with dyspnea. PICC line placement.\nComparison: None.",
+             "How would you characterize the findings from <img0>?",
+             "cxr",
+         ],
+         [
+             ["./demo_ex/79eee504-b1b60ab8-5e8dd843-b6ed87aa-670747b1.png"],
+             "Age:70-80.\nGender:F.\nIndication: Respiratory distress.\nComparison: None.",
+             "How would you characterize the findings from <img0>?",
+             "cxr",
+         ],
+         [
+             ["./demo_ex/f39b05b1-f544e51a-cfe317ca-b66a4aa6-1c1dc22d.png", "./demo_ex/f3fefc29-68544ac8-284b820d-858b5470-f579b982.png"],
+             "Age:80-90.\nGender:F.\nIndication: ___-year-old female with history of chest pain.\nComparison: None.",
+             "How would you characterize the findings from <img0><img1>?",
+             "cxr",
+         ],
+         [
+             ["./demo_ex/1de015eb-891f1b02-f90be378-d6af1e86-df3270c2.png"],
+             "Age:40-50.\nGender:M.\nIndication: ___-year-old male with shortness of breath.\nComparison: None.",
+             "How would you characterize the findings from <img0>?",
+             "cxr",
+         ],
+         [
+             ["./demo_ex/bc25fa99-0d3766cc-7704edb7-5c7a4a63-dc65480a.png"],
+             "Age:40-50.\nGender:F.\nIndication: History: ___F with tachyacrdia cough doe // infilatrate\nComparison: None.",
+             "How would you characterize the findings from <img0>?",
+             "cxr",
+         ],
+         [
+             ["./demo_ex/ISIC_0032258.jpg"],
+             "Age:70.\nGender:female.\nLocation:back.",
+             "What is the primary diagnosis?",
+             "derm",
+         ],
+         [
+             ["./demo_ex/Case_01013_0000.nii.gz"],
+             "",
+             "Segment the liver.",
+             "ct",
+         ],
+         [
+             ["./demo_ex/Case_00840_0000.nii.gz"],
+             "",
+             "Segment the liver.",
+             "ct",
+         ],
+     ]
+ 
+     gr.Examples(examples, inputs=[image_input, context_input, prompt_input, modality_input])
+ 
+ # Run Gradio app
+ demo.launch(share=True, show_error=True)