import matplotlib.pyplot as plt
import numpy as np
import torch
import transformers

import common


def encode_phonemes(
    phonemes: list[str] | str,
    tokenizer: transformers.Wav2Vec2Tokenizer
) -> torch.Tensor:
    """
    Turn a list of phonemes into a logits-like one-hot matrix.

    :param list[str] | str phonemes: List of individual phonemes to encode.
    :param tokenizer: The tokenizer to use.
    :return torch.Tensor: Encodings for the phonemes, of size (1, len(phonemes), n_tokens).
    """
    encodings = torch.zeros(
        (1, len(phonemes), len(tokenizer.encoder) + 2),
        dtype=torch.uint8
    )
    label_ids = tokenizer.encode(phonemes)
    # One-hot: each row gets a 1 at the token id of the corresponding phoneme
    for i, label_id in enumerate(label_ids):
        encodings[0, i, label_id] = 1
    return encodings
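
# Usage sketch (illustrative; assumes common.get_model() returns a (model, processor)
# pair whose processor exposes its tokenizer as `processor.tokenizer`):
# _model, _processor = common.get_model()
# _target = encode_phonemes(["a", "p", "a"], _processor.tokenizer)
# _target.shape                 # (1, 3, len(_processor.tokenizer.encoder) + 2)
# torch.argmax(_target, -1)[0]  # recovers the three token ids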


def l2_logit_norm(prediction, target):
    """
    L2 distance between two vectors, scaled to roughly [0, 1].

    Close to 0 for two similar vectors, close to 1 for very different ones.
    """
    # 1.414 ~ sqrt(2), the L2 distance between two distinct one-hot vectors
    val = torch.norm(prediction - target) / 1.414
    return val
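
# Sanity-check sketch (illustrative): identical one-hot vectors give 0, while two
# distinct one-hot vectors are at distance sqrt(2), so the metric is ~1.
# _a = torch.tensor([1.0, 0.0, 0.0])
# _b = torch.tensor([0.0, 1.0, 0.0])
# l2_logit_norm(_a, _a)  # tensor(0.)
# l2_logit_norm(_a, _b)  # ~1 (sqrt(2) / 1.414)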


def cosine_similarity(prediction, target):
    """
    Cosine-based dissimilarity between a softmax-normalized prediction and a target.

    Returns 0 for vectors pointing in the same direction, 1 for opposite directions.
    """
    normed = torch.softmax(prediction, 0)
    similarity = (1 - torch.nn.CosineSimilarity(dim=0)(normed, target)) / 2
    return similarity  # * torch.norm(prediction)


def argmax_selection(prediction, target):
    """
    Return 1 - prediction[argmax(target)], assuming the prediction is already normalized.

    0 for a perfect match, 1 for a completely different vector.
    """
    return 1 - prediction[torch.argmax(target)]


def plot_metric(metric, prediction, target):
    """Plot the result of a metric."""
    fig, ax = plt.subplots()
    _model, processor = common.get_model()
    predicted_labels = [
        processor.decode(i) if i < prediction.shape[0] - 4 else ""
        for i in range(prediction.shape[0])
    ]
    normed = torch.softmax(prediction, 0)
    ax.plot(normed, label="Normed prediction")
    ax.scatter([torch.argmax(target).item()], [1], label="Target", marker="X")
    ax.set_xticks(range(prediction.shape[0]), predicted_labels)
    value = 1 - metric(prediction, target)
    ax.plot([0, normed.shape[0]], [value, value], label="1 - Metric value (1 = perfect)")
    plt.legend()
    plt.show()


def compute_path_matrix(prediction, target, metric, insertion_cost, deletion_cost):
    """Compute the dynamic-programming alignment cost matrix between two sequences."""
    # Define the matrix
    path_matrix = torch.empty((prediction.shape[1], target.shape[1]))
    # Fill it iteratively, edit-distance style
    for i, pred_column in enumerate(prediction[0]):
        for j, target_column in enumerate(target[0]):
            if i == 0 and j == 0:
                path_matrix[i, j] = 0
            elif i == 0:
                path_matrix[0, j] = j * insertion_cost
            elif j == 0:
                path_matrix[i, 0] = i * deletion_cost
            else:
                # plot_metric(metric, pred_column, target_column)
                path_matrix[i, j] = min(
                    path_matrix[i - 1, j - 1] + metric(pred_column, target_column),
                    path_matrix[i - 1, j] + deletion_cost,
                    path_matrix[i, j - 1] + insertion_cost
                )
    return path_matrix
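
# The recurrence above is the classic edit-distance / DTW scheme, written out here
# for reference (i indexes prediction frames, j indexes target phonemes):
#   D[0, 0] = 0
#   D[0, j] = j * insertion_cost
#   D[i, 0] = i * deletion_cost
#   D[i, j] = min(D[i-1, j-1] + metric(prediction[i], target[j]),
#                 D[i-1, j] + deletion_cost,
#                 D[i, j-1] + insertion_cost)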


def solve_path(prediction, target, path_matrix):
    """
    Find the matching path between a prediction, a target and a path matrix.

    For each step we minimize the cost.
    """
    line, col = prediction.shape[1] - 1, target.shape[1] - 1
    matching = []
    while line > 0 or col > 0:
        matching.append((line, col))
        directions = []
        if line > 0 and col > 0:
            directions.append((line - 1, col - 1))
        if line > 0:
            directions.append((line - 1, col))
        if col > 0:
            directions.append((line, col - 1))
        best_score = float("inf")
        dir_index = -1
        for i, direction in enumerate(directions):
            if path_matrix[direction[0]][direction[1]] < best_score:
                best_score = path_matrix[direction[0]][direction[1]]
                dir_index = i
        line, col = directions[dir_index]
    matching.reverse()
    return matching


def display_matrix_result(path_matrix, matching, prediction, target, processor=None):
    """Display all the information resulting from a Bellman matching of matrices.

    Returns the figure instead of showing it directly, for use in Gradio.
    """
    fig, axis = plt.subplots(figsize=(12, 8))
    if processor is None:
        _model, processor = common.get_model()
    # Display the matrix
    im = axis.matshow(path_matrix.T, aspect="auto", cmap='Blues')
    cbar = plt.colorbar(im, ax=axis)
    cbar.set_label('Alignment Cost', rotation=270, labelpad=20, fontsize=11)
    # Set the labels for the axes with clearer names
    axis.set_xlabel('Predicted Phoneme Sequence', fontsize=12)
    axis.set_ylabel('Target Phoneme Sequence', fontsize=12)
    axis.set_title('Phoneme Alignment Matrix\n(Darker Blue = Higher Cost, Red Line = Optimal Path)',
                   fontsize=14, pad=20)
    # Get phoneme labels for both axes
    predicted_labels = tuple(map(processor.decode, torch.argmax(prediction, -1)[0]))
    target_labels = tuple(map(processor.decode, torch.argmax(target, -1)[0]))
    # Set x-axis ticks (predicted phonemes)
    non_empty_pred_indices = [i for i, label in enumerate(predicted_labels) if label not in ("", "[PAD]")]
    non_empty_pred_labels = [label for i, label in enumerate(predicted_labels) if label not in ("", "[PAD]")]
    if non_empty_pred_indices:
        axis.set_xticks(non_empty_pred_indices)
        axis.set_xticklabels(non_empty_pred_labels, rotation=45, ha='right', fontsize=10)
    # Set y-axis ticks (target phonemes)
    non_empty_target_indices = [i for i, label in enumerate(target_labels) if label not in ("", "[PAD]")]
    non_empty_target_labels = [label for i, label in enumerate(target_labels) if label not in ("", "[PAD]")]
    if non_empty_target_indices:
        axis.set_yticks(non_empty_target_indices)
        axis.set_yticklabels(non_empty_target_labels, fontsize=10)
    # Add subtle grid
    axis.grid(which="major", color="gray", alpha=0.2, linestyle="-")
    # Plot the optimal path in red with better visibility
    if matching:
        axis.plot(
            [val[0] for val in matching],
            [val[1] for val in matching],
            color="red",
            linewidth=3,
            marker='o',
            markersize=4,
            markerfacecolor='white',
            markeredgecolor='red',
            markeredgewidth=2,
            label="Optimal Alignment Path",
            alpha=0.9
        )
    # Add legend with better positioning
    axis.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0), fontsize=11)
    # Add text annotations for better understanding
    axis.text(
        0.02, 0.98, 'Lower values indicate\nbetter alignment',
        transform=axis.transAxes, fontsize=9, va='top', ha='left',
        bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
    )
    plt.tight_layout()
    return fig
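
# Usage sketch (illustrative): the returned figure can be shown, saved, or handed to
# a Gradio Plot component, e.g.
# _fig = display_matrix_result(path_matrix, matching, padded_prediction, padded_target)
# _fig.savefig("alignment.png")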


def bellman_matching(prediction, target, insertion_cost=1.3, deletion_cost=3, metric=l2_logit_norm):
    """
    Match two sequences with Bellman's algorithm.

    :param prediction: Actual prediction.
    :param target: Target list of values.
    :param float insertion_cost: Something was added in prediction.
    :param float deletion_cost: Something was missing in prediction.
    :param Callable metric: The metric to use.
    :return tuple(list, float): Best alignment [(prediction[i], target[j]), ...] for all elements, and its score.
    """
    # Add padding: start matching on letters (do not penalize kids starting with insertions or long audio)
    padded_target = torch.zeros((target.shape[0], target.shape[1] + 1, target.shape[2]))
    padded_target[0, 1:] = target
    padded_prediction = torch.zeros((prediction.shape[0], prediction.shape[1] + 1, prediction.shape[2]))
    padded_prediction[0, 1:] = prediction
    path_matrix = compute_path_matrix(
        padded_prediction, padded_target, metric,
        insertion_cost,
        deletion_cost
    )
    # Now solve the path, finding the best diagonal
    padded_matching = solve_path(padded_prediction, padded_target, path_matrix)
    short_matching = []
    for match in padded_matching:
        if match[0] == 0 or match[1] == 0:
            continue
        short_matching.append((match[0] - 1, match[1] - 1))
        if match[1] == padded_target.shape[1] - 1:
            break
    # display_matrix_result(path_matrix, padded_matching, padded_prediction, padded_target)
    # Initial padding should not reduce the score
    score = path_matrix[padded_matching[-1]]
    return short_matching, score.item()
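
# Usage sketch (illustrative, with hypothetical shapes): align 5 prediction frames
# with a 3-phoneme one-hot target.
# _model, _processor = common.get_model()
# _target = encode_phonemes(["a", "p", "a"], _processor.tokenizer)
# _prediction = torch.softmax(torch.randn(1, 5, _target.shape[2]), dim=-1)
# _matching, _score = bellman_matching(_prediction, _target)
# _matching is a list of (prediction_index, target_index) pairs and _score is the
# total alignment cost (lower is better).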


def score_correct(matching, prediction, target, threshold):
    """Count the number of correct phonemes in the target."""
    # Now from the matching, count the errors
    insertions = deletions = substitutions = 0
    for i, match in enumerate(matching[1:]):
        if np.all(match - matching[i] == [0, 1]):
            # Deletion occurred
            deletions += 1
        elif np.all(match - matching[i] == [1, 0]):
            # Insertion
            insertions += 1
        else:
            # Match probability, 1 == good match
            # plot_metric(argmax_selection, reduced_logits[0, match[0]], target[0, match[1]])
            match_value = 1 - argmax_selection(prediction[0, match[0]], target[0, match[1]])
            if match_value < threshold:
                substitutions += 1
    return max(0, target.shape[1] - insertions - deletions - substitutions)


def score_phoneme_deletion(matching, prediction, target, threshold):
    """
    Score a phoneme-deletion exercise from a matching.

    The first phoneme should NOT match (it is supposed to be deleted); otherwise the
    score is 2 for an error-free alignment, 1 for a single error, and 0 otherwise.
    """
    # Now from the matching, count the errors
    insertions = deletions = substitutions = 0
    for i, match in enumerate(matching[1:]):
        if np.all(match - matching[i] == [0, 1]):
            # Deletion occurred
            deletions += 1
        elif np.all(match - matching[i] == [1, 0]):
            # Insertion
            insertions += 1
        else:
            # Match probability, 1 == good match
            # plot_metric(argmax_selection, reduced_logits[0, match[0]], target[0, match[1]])
            match_value = 1 - argmax_selection(prediction[0, match[0]], target[0, match[1]])
            if match_value < threshold:
                substitutions += 1
    # First phoneme should NOT match
    if 0 in matching[0]:
        indices = np.argwhere(matching[:, 0] == 0).flatten()
        for i in indices:
            match_value = 1 - argmax_selection(
                prediction[0, matching[i, 0]],
                target[0, matching[i, 1]]
            )
            if match_value > threshold:
                return 0
    if insertions + deletions + substitutions == 0:
        return 2
    if insertions + deletions + substitutions == 1:
        return 1
    return 0


def remove_pad_tokens(prediction, pad_token_id, temperature):
    """
    Remove the pad token from a prediction to decrease temporal effects.

    :param prediction: Predicted logits.
    :param int pad_token_id: ID of the pad token.
    :param float temperature: Temperature to pass to the softmax.
    :return torch.Tensor: Probabilities where no row has the pad token id as its argmax.
    """
    probabilities = torch.softmax(
        torch.as_tensor(prediction) / temperature,
        dim=-1
    )
    # Keep only the frames whose most likely token is not the pad token
    reduced = probabilities[torch.argmax(probabilities, -1) != pad_token_id]
    return reduced.reshape((1, reduced.shape[0], reduced.shape[1]))
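
# Sketch (illustrative, with a hypothetical pad_token_id of 0): frames whose argmax
# is the pad token are dropped before alignment.
# _logits = torch.tensor([[[5.0, 0.0, 0.0],    # argmax = pad -> removed
#                          [0.0, 4.0, 0.0],    # kept
#                          [0.0, 0.0, 3.0]]])  # kept
# remove_pad_tokens(_logits, pad_token_id=0, temperature=1.0).shape  # (1, 2, 3)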


def get_alignment_score(
    prediction,
    target,
    weights,
    pad_token_id=58,
    scoring=common.Scoring.NUMBER_CORRECT
):
    """
    Get a score from a prediction and a target.

    Both the prediction and the target should be logits. The result depends on the
    type of scoring: the number of correct phonemes for NUMBER_CORRECT, or a grade
    of 0, 1 or 2 for PHONEME_DELETION.

    :param prediction: The output of the model, without activation function.
    :param target: The logits we have to match.
    :param weights: A sequence of weights to apply: (insertion_cost, deletion_cost, threshold, temperature).
    :param int pad_token_id: Index of elements in the sequence that should be ignored.
    :param common.Scoring scoring: Type of scoring to use.
    :return int: The score.
    """
    collapsed_prediction = remove_pad_tokens(prediction, pad_token_id, weights[3])
    matching, alignment_score = bellman_matching(
        collapsed_prediction,
        target,
        insertion_cost=weights[0],
        deletion_cost=weights[1],
        metric=l2_logit_norm
    )
    np_matching = np.array(matching)
    if scoring is common.Scoring.NUMBER_CORRECT:
        return score_correct(np_matching, collapsed_prediction, target, weights[2])
    if scoring is common.Scoring.PHONEME_DELETION:
        return score_phoneme_deletion(np_matching, collapsed_prediction, target, weights[2])
    raise NotImplementedError("Unknown scoring method.")
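
# End-to-end sketch (illustrative; the weight layout and default pad_token_id follow
# the code above, the threshold and temperature values are hypothetical, and the model
# is assumed to return CTC-style logits):
# _model, _processor = common.get_model()
# _target = encode_phonemes(["a", "p", "a"], _processor.tokenizer)
# _logits = _model(input_values).logits.detach()  # `input_values` comes from the processor
# _weights = (1.3, 3, 0.5, 1.0)  # (insertion_cost, deletion_cost, threshold, temperature)
# get_alignment_score(_logits, _target, _weights, scoring=common.Scoring.NUMBER_CORRECT)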