# Module preamble: imports regrouped at the top per PEP 8
# (stdlib first, then third-party), previously scattered mid-file.
import os
from collections import defaultdict

import requests
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Standard 224x224 preprocessing with ImageNet mean/std — the constants
# expected by most pretrained torchvision backbones.
transformten = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Cache of decoded PIL images keyed by file path so repeated preprocessing
# of the same image skips disk I/O. Unbounded — acceptable for a fixed
# dataset, but consider an LRU bound for very large corpora.
image_cache = {}
def preprocess_image(image_source):
    """
    Preprocess a single image for inference.

    Args:
        image_source: A URL (anything starting with "http"), a local file
            path, or an already-decoded ``PIL.Image.Image``.

    Returns:
        A normalized float tensor of shape [3, 224, 224]
        (see ``transformten``).

    Raises:
        ValueError: If ``image_source`` is neither a string nor a PIL image.
        requests.HTTPError: If a URL download returns an error status.
    """
    if isinstance(image_source, str):
        if image_source.startswith("http"):
            # Timeout prevents hanging forever on a dead host;
            # raise_for_status surfaces 4xx/5xx instead of trying to
            # decode an HTML error page as an image.
            response = requests.get(image_source, stream=True, timeout=30)
            response.raise_for_status()
            image = Image.open(response.raw).convert("RGB")
        else:
            image = Image.open(image_source).convert("RGB")
    elif isinstance(image_source, Image.Image):
        # Bug fix: PIL inputs were previously used as-is, so grayscale/RGBA
        # images produced the wrong channel count after ToTensor. Convert
        # to RGB for consistency with the string-input paths.
        image = image_source.convert("RGB")
    else:
        raise ValueError("Unsupported image_source type")

    return transformten(image)
| | |
def preprocess_example(example):
    """
    Convert one raw dataset row into model-ready fields.

    Args:
        example: Mapping with "image" (path/URL string), "question",
            "answer", and "question_class" keys.

    Returns:
        Dict with the transformed image tensor, tokenized question
        ("input_ids" / "attention_mask", each [32]), and the pass-through
        answer, question_class, and original image string.
    """
    # Bug fix: the tokenizer used to be re-instantiated with
    # from_pretrained() on EVERY call, which dominated preprocessing time.
    # Cache it lazily on the function object so only the first call pays.
    if not hasattr(preprocess_example, "_tokenizer"):
        preprocess_example._tokenizer = AutoTokenizer.from_pretrained(
            "distilbert-base-uncased")
    router_tokenizer = preprocess_example._tokenizer

    # Images are looked up by basename inside the Kaggle input directory.
    image_name = example["image"].split("/")[-1]
    image_path = os.path.join("/kaggle/input/medico2025", image_name)

    # Decode each image at most once; module-level image_cache is shared
    # across calls.
    if image_path in image_cache:
        image = image_cache[image_path]
    else:
        image = Image.open(image_path)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_cache[image_path] = image

    image = transformten(image)

    # Fixed-length (32) padded encoding so questions batch without a
    # dynamic-padding collator.
    q_inputs = router_tokenizer(example["question"],
                                return_tensors="pt",
                                truncation=True,
                                padding="max_length",
                                max_length=32)

    # Drop the batch dimension the tokenizer adds for a single example.
    input_ids = q_inputs["input_ids"].squeeze(0)
    attention_mask = q_inputs["attention_mask"].squeeze(0)

    return {
        "image": image,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "answer": example["answer"],
        "question_class": example["question_class"],
        "image_url": example["image"],
    }
| |
|
def normalize_answer(ans, q_type):
    """
    Map a free-text answer onto the canonical label space for its type.

    Args:
        ans: Raw answer string.
        q_type: General question class — one of "yesno", "count", "color",
            "location", "single", "multi" (any other type falls through).

    Returns:
        The canonical label, or None when the answer cannot be mapped for
        that question type. "yesno" yields "Yes"/"No"; every other type
        yields a lowercase string.
    """
    import re

    ans = ans.strip().lower()

    if q_type == "yesno":
        # Bug fix: the old substring tests misfired badly — "abnormal"
        # contains "no" and "no evidence of X" matched "evidence" first,
        # returning "Yes". Use whole-word matches and check negations
        # FIRST so "no evidence ..." maps to "No".
        if re.search(r"\b(no|not|absent|none|without)\b", ans):
            return "No"
        if re.search(r"\b(yes|present|evidence)\b", ans):
            return "Yes"
        return None

    if q_type == "count":
        # Prefer an explicit digit; fall back to spelled-out small counts.
        numbers = re.findall(r"\d+", ans)
        if numbers:
            return numbers[0]
        # Word boundaries: the old "one" in ans matched "none" -> "1".
        if re.search(r"\bone\b", ans):
            return "1"
        if re.search(r"\btwo\b", ans):
            return "2"
        return None

    if q_type == "color":
        for color in ["red", "green", "yellow", "blue", "white", "black"]:
            if color in ans:
                return color
        return None

    if q_type == "location":
        # First match wins, in this fixed priority order.
        for loc in ["upper", "lower", "left", "right", "central"]:
            if loc in ans:
                return loc
        return None

    # "single"/"multi" and any unknown type: return the normalized text.
    return ans
| |
|
| |
|
def build_vocabs(dataset, q_types_mapping):
    """
    Build one answer->index vocabulary per general question class.

    Args:
        dataset: Iterable of rows exposing "question_class" and "answer".
        q_types_mapping: Maps fine-grained question classes to general ones.

    Returns:
        Dict mapping each general class to a {normalized_answer: index}
        dict; indices are assigned in first-seen order. Answers that
        normalize to None are skipped.
    """
    task_vocabs = {general: {} for general in set(q_types_mapping.values())}

    for row in dataset:
        fine = row["question_class"]
        # Some rows store the class as a one-element list — unwrap it.
        if isinstance(fine, list):
            fine = fine[0]

        general = q_types_mapping[fine]
        label = normalize_answer(row["answer"], general)
        if label is None:
            continue

        vocab = task_vocabs[general]
        if label not in vocab:
            vocab[label] = len(vocab)

    return task_vocabs
| |
|
| |
|
def build_answer_vocab(dataset, q_types_mapping):
    """
    Build a {raw_answer: index} vocabulary for each general question class.

    Args:
        dataset: Columnar mapping with parallel "answer" and
            "question_class" sequences (e.g. a HuggingFace Dataset).
        q_types_mapping: Maps fine-grained question classes to general ones.

    Returns:
        defaultdict mapping general class -> {answer: index}; indices are
        assigned per class in first-seen order. Answers are NOT normalized.
    """
    answer_vocab = defaultdict(dict)

    for ans, q_class in zip(dataset["answer"], dataset["question_class"]):
        # Unwrap one-element-list question classes.
        key = q_class[0] if isinstance(q_class, list) else q_class
        vocab = answer_vocab[q_types_mapping[key]]
        # The next free index is simply the current vocab size, so no
        # separate per-class counter is needed.
        if ans not in vocab:
            vocab[ans] = len(vocab)

    return answer_vocab
| |
|
| |
|
| |
|
def collate_fn(batch):
    """
    Assemble a list of preprocessed examples into batched tensors.

    Args:
        batch: List of dicts as produced by ``preprocess_example``; the
            tensor-valued fields may arrive as plain Python lists (e.g.
            after an Arrow round-trip) and are converted back to tensors.

    Returns:
        Dict with stacked "images", "input_ids", and "attention_mask"
        tensors plus pass-through "answers" and "question_classes" lists.
    """
    def stack_field(key):
        # Re-tensorize any field that was serialized to a plain list,
        # then stack along a new leading batch dimension.
        values = [
            torch.tensor(item[key]) if isinstance(item[key], list) else item[key]
            for item in batch
        ]
        return torch.stack(values)

    return {
        "images": stack_field("image"),
        "input_ids": stack_field("input_ids"),
        "attention_mask": stack_field("attention_mask"),
        "answers": [item["answer"] for item in batch],
        "question_classes": [item["question_class"] for item in batch],
    }
| |
|
| |
|