| import torch |
| from transformers import AutoTokenizer, BatchEncoding |
|
|
| from mixinhelpers import CXR_Mixin, ECG_Mixin, ECHO_Mixin, Text_Mixin |
|
|
| """ |
| Preprocessor classes for different modalities and their combinations. |
| You can combine different mixins to create preprocessors for multi-modal inputs. |
| Examples below are provided for ECHO+Text, ECG+Text, and CXR+Text. |
| """ |
|
|
|
|
class BasePreprocessor:
    """Common base for the modality+text preprocessors; owns the text tokenizer.

    NOTE(review): ``__init__`` does not chain to ``super().__init__()``, so an
    ``__init__`` defined on any class that follows this one in a subclass's
    MRO would not run -- confirm the mixins need no initialisation.
    """

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        """Create the tokenizer and store it on ``self.tokenizer``.

        Args:
            model_name: Identifier handed to ``AutoTokenizer.from_pretrained``.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer = tokenizer
|
|
|
|
| |
class ECHOText_Preprocessor(BasePreprocessor, ECHO_Mixin, Text_Mixin):
    """Preprocessor for paired ECHO + text-report samples."""

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        """Forward ``model_name`` to the next ``__init__`` in the MRO."""
        super().__init__(model_name=model_name)

    def preprocess_echo_text(
        self, echo_path: str, text: str
    ) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one (echo, report) pair, e.g. from a dataloader.

        The caption is built with ``modality=self.ECHO_KEY``; that string key
        is what identifies the modality when batches are collated.

        Args:
            echo_path: Path to the echo ``.npy`` file.
            text: Text report for the same sample.

        Returns:
            ``(echo tensor, tokenized text)`` -- the outputs of
            ``preprocess_single_echo`` and ``construct_caption`` respectively.
        """
        # The echo is loaded before the caption is built, as in the original.
        echo_tensor = self.preprocess_single_echo(echo_path)
        caption_tokens = self.construct_caption(
            caption=text,
            tokenizer=self.tokenizer,
            modality=self.ECHO_KEY,
        )
        return echo_tensor, caption_tokens
|
|
|
|
class ECGText_Preprocessor(BasePreprocessor, ECG_Mixin, Text_Mixin):
    """Preprocessor for paired ECG + text-report samples."""

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        """Forward ``model_name`` to the next ``__init__`` in the MRO."""
        super().__init__(model_name=model_name)

    def preprocess_ecg_text(
        self, ecg_path: str, text: str
    ) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one (ecg, report) pair, e.g. from a dataloader.

        The caption is built with ``modality=self.ECG_KEY``; that string key
        is what identifies the modality when batches are collated.

        Args:
            ecg_path: Path to the ecg ``.npy`` file.
            text: Text report for the same sample.

        Returns:
            ``(ecg tensor, tokenized text)`` -- the outputs of
            ``preprocess_single_ecg`` and ``construct_caption`` respectively.
        """
        # The signal is loaded before the caption is built, as in the original.
        ecg_tensor = self.preprocess_single_ecg(ecg_path)
        caption_tokens = self.construct_caption(
            caption=text,
            tokenizer=self.tokenizer,
            modality=self.ECG_KEY,
        )
        return ecg_tensor, caption_tokens
|
|
|
|
class CXRText_Preprocessor(BasePreprocessor, CXR_Mixin, Text_Mixin):
    """Preprocessor for paired CXR + text-report samples."""

    def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
        """Forward ``model_name`` to the next ``__init__`` in the MRO."""
        super().__init__(model_name=model_name)

    def preprocess_cxr_text(
        self, cxr_path: str, text: str
    ) -> tuple[torch.Tensor, BatchEncoding]:
        """Preprocess one (cxr, report) pair, e.g. from a dataloader.

        The caption is built with ``modality=self.VISION_KEY``; that string
        key is what identifies the modality when batches are collated.

        Args:
            cxr_path: Path to the cxr image file.
            text: Text report for the same sample.

        Returns:
            ``(cxr tensor, tokenized text)`` -- the outputs of
            ``preprocess_single_cxr`` and ``construct_caption`` respectively.
        """
        # The image is loaded before the caption is built, as in the original.
        cxr_tensor = self.preprocess_single_cxr(cxr_path)
        caption_tokens = self.construct_caption(
            caption=text,
            tokenizer=self.tokenizer,
            modality=self.VISION_KEY,
        )
        return cxr_tensor, caption_tokens
|
|