| import torch | |
| import logging | |
| from torch import Tensor | |
| from typing import Mapping | |
| def _setup_logger(): | |
| log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") | |
| logger = logging.getLogger() | |
| logger.setLevel(logging.INFO) | |
| console_handler = logging.StreamHandler() | |
| console_handler.setFormatter(log_format) | |
| logger.handlers = [console_handler] | |
| return logger | |
| logger = _setup_logger() | |
def move_to_cuda(sample):
    """Recursively move every tensor inside ``sample`` to the current CUDA device.

    Tensors are copied with ``.cuda(non_blocking=True)``. Dicts, lists, tuples
    (namedtuples keep their type) and other ``Mapping`` types are rebuilt with
    their contents moved. Any other object is returned unchanged.

    Args:
        sample: a tensor, or an arbitrarily nested container of tensors and
            other values.

    Returns:
        The same structure with tensors moved to the GPU. An empty dict, list,
        tuple or other ``Mapping`` at the top level yields ``{}``; this is kept
        for backward compatibility with existing callers.
    """
    # Only short-circuit on empty *containers*. Calling ``len`` on a 0-d
    # tensor or a non-sized object raises TypeError, and an empty 1-d tensor
    # should still be moved rather than replaced by ``{}``.
    if isinstance(sample, (dict, list, tuple, Mapping)) and len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            # Per the torch docs, non_blocking only makes the copy
            # asynchronous when the source tensor is in pinned host memory.
            return maybe_tensor.cuda(non_blocking=True)
        if isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        if isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            moved = [_move_to_cuda(x) for x in maybe_tensor]
            # Namedtuples take their fields positionally; a plain tuple(...)
            # would silently drop the type and its attribute access.
            if hasattr(maybe_tensor, "_fields"):
                return type(maybe_tensor)(*moved)
            return tuple(moved)
        if isinstance(maybe_tensor, Mapping):
            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
        return maybe_tensor

    return _move_to_cuda(sample)
def pool(last_hidden_states: Tensor,
         attention_mask: Tensor,
         pool_type: str) -> Tensor:
    """Reduce per-token hidden states to one embedding per sequence.

    Args:
        last_hidden_states: token representations, presumably shaped
            ``(batch, seq_len, hidden)``. Dim 1 is reduced/indexed as the
            sequence axis below.
        attention_mask: mask broadcast against ``last_hidden_states`` via
            ``[..., None]``. Nonzero entries mark tokens to keep and zero
            entries mark padding.
        pool_type: ``"avg"`` for the mean over unmasked tokens, or ``"cls"``
            for the (masked) hidden state at position 0.

    Returns:
        One pooled vector per sequence (dim 1 of the input removed).

    Raises:
        ValueError: if ``pool_type`` is not ``"avg"`` or ``"cls"``.
    """
    # Zero out padded positions so they contribute nothing to the sum below.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    if pool_type == "avg":
        # Clamp the token count to at least 1 so an all-padding row yields a
        # zero vector instead of 0 / 0 = NaN.
        token_counts = attention_mask.sum(dim=1).clamp(min=1)[..., None]
        emb = last_hidden.sum(dim=1) / token_counts
    elif pool_type == "cls":
        emb = last_hidden[:, 0]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")
    return emb