Hugo Farajallah
chore(dataset_process): renames it to "aligner" as it better describes its role.
b57fa93
import matplotlib.pyplot as plt
import numpy as np
import torch
import transformers
import common
def encode_phonemes(
        phonemes: list[str] | str,
        tokenizer: transformers.Wav2Vec2Tokenizer
) -> torch.Tensor:
    """
    Build a one-hot, logits-shaped matrix from a phoneme sequence.

    :param list[str] | str phonemes: Phonemes to encode, one per position.
    :param tokenizer: Tokenizer providing ``encode`` and the ``encoder`` vocabulary.
    :return torch.Tensor: ``uint8`` tensor of size
        (1, len(phonemes), len(tokenizer.encoder) + 2) holding a 1 at
        ``[0, i, id_i]`` for the i-th id returned by ``tokenizer.encode``.
    """
    # NOTE(review): the "+ 2" presumably leaves room for tokens that are not
    # listed in ``tokenizer.encoder`` -- confirm against the model output size.
    vocabulary_width = len(tokenizer.encoder) + 2
    one_hot = torch.zeros(
        (1, len(phonemes), vocabulary_width),
        dtype=torch.uint8
    )
    for position, token_id in enumerate(tokenizer.encode(phonemes)):
        one_hot[0, position, token_id] = 1
    return one_hot
def l2_logit_norm(prediction, target):
    """
    Scaled Euclidean (L2) distance between two vectors.

    The norm of the difference is divided by 1.414 (roughly sqrt(2)), so two
    identical vectors give 0 and two distinct one-hot vectors give about 1.

    :param prediction: First vector.
    :param target: Second vector, same shape as ``prediction``.
    :return torch.Tensor: Scalar tensor holding the scaled distance.
    """
    difference = prediction - target
    return difference.norm() / 1.414
def cosine_similarity(prediction, target):
    """
    Cosine-based distance between a softmaxed prediction and a target.

    ``prediction`` is passed through a softmax over dimension 0, then the
    cosine similarity with ``target`` is mapped through ``(1 - cos) / 2`` so
    that a perfectly aligned pair gives 0.

    :param prediction: Unnormalised scores (softmax is applied here).
    :param target: Vector to compare against, same shape as ``prediction``.
    :return torch.Tensor: Scalar tensor, 0 when both vectors point the same way.
    """
    probabilities = torch.softmax(prediction, 0)
    cosine = torch.nn.functional.cosine_similarity(probabilities, target, dim=0)
    return (1 - cosine) / 2
def argmax_selection(prediction, target):
    """
    Return ``1 - prediction[argmax(target)]``.

    No normalisation is applied here: the result is 0 when the prediction
    holds 1 at the target's strongest index and 1 when it holds 0 there, so
    the caller is expected to pass an already normalised prediction.

    :param prediction: Vector of (normalised) scores.
    :param target: Vector whose argmax selects the entry of ``prediction``.
    :return torch.Tensor: Scalar tensor, 0 for a perfect match.
    """
    target_index = torch.argmax(target)
    return 1 - prediction[target_index]
def plot_metric(metric, prediction, target, processor=None):
    """
    Plot a softmaxed prediction, its target index and the value of a metric.

    :param Callable metric: Function ``metric(prediction, target)``; the plot
        shows ``1 - metric(...)`` as a horizontal line.
    :param prediction: 1-D vector of scores; a softmax over dimension 0 is
        applied before plotting.
    :param target: Vector whose argmax is marked on the plot.
    :param processor: Object exposing ``decode`` used for the tick labels.
        When None it is obtained from ``common.get_model()``, which lets
        callers that already hold a processor avoid another model lookup.
    """
    if processor is None:
        _model, processor = common.get_model()
    n_classes = prediction.shape[0]
    # Materialise the labels as a list: Matplotlib expects a sequence of
    # strings, not a one-shot generator.
    # NOTE(review): the last four indices get a blank label, presumably
    # because they are special tokens -- confirm against the vocabulary.
    tick_labels = [
        processor.decode(index) if index < n_classes - 4 else ""
        for index in range(n_classes)
    ]
    normed = torch.softmax(prediction, 0)
    _fig, ax = plt.subplots()
    ax.plot(normed, label="Normed prediction")
    ax.scatter([torch.argmax(target).item()], [1], label="Target", marker="X")
    ax.set_xticks(range(n_classes), tick_labels)
    value = 1 - metric(prediction, target)
    ax.plot([0, normed.shape[0]], [value, value], label="1 - Metric value (1 = perfect)")
    ax.legend()
    plt.show()
def compute_path_matrix(prediction, target, metric, insertion_cost, deletion_cost):
    """
    Fill the cumulative-cost matrix used to align two sequences.

    Cell ``[i, j]`` holds the cheapest cost of aligning the first ``i + 1``
    prediction frames with the first ``j + 1`` target frames. Only entry 0 of
    the leading (batch) axis of both inputs is used.

    NOTE(review): a step along the prediction axis is charged
    ``deletion_cost`` and a step along the target axis ``insertion_cost``;
    this looks swapped with respect to the wording used by the callers --
    confirm before changing, the weights may be tuned against it.

    :param prediction: Tensor indexed as (batch, frame, token).
    :param target: Tensor indexed as (batch, frame, token).
    :param Callable metric: ``metric(pred_frame, target_frame)`` giving the
        cost of pairing two frames.
    :param insertion_cost: Cost of advancing along the target axis only.
    :param deletion_cost: Cost of advancing along the prediction axis only.
    :return torch.Tensor: Matrix of size (prediction.shape[1], target.shape[1]).
    """
    costs = torch.empty((prediction.shape[1], target.shape[1]))
    for row, pred_frame in enumerate(prediction[0]):
        for col, target_frame in enumerate(target[0]):
            if row and col:
                # Interior cell: cheapest of pairing the two frames or
                # skipping one of them.
                pair_frames = costs[row - 1, col - 1] + metric(pred_frame, target_frame)
                skip_prediction = costs[row - 1, col] + deletion_cost
                skip_target = costs[row, col - 1] + insertion_cost
                costs[row, col] = min(pair_frames, skip_prediction, skip_target)
            elif row:
                # First column: only prediction frames consumed so far.
                costs[row, 0] = row * deletion_cost
            elif col:
                # First row: only target frames consumed so far.
                costs[0, col] = col * insertion_cost
            else:
                costs[0, 0] = 0
    return costs
def solve_path(prediction, target, path_matrix):
    """
    Walk the cost matrix back from its last cell to recover an alignment.

    Starting at the bottom-right cell, the walk repeatedly moves to the
    neighbour (diagonal, up, left -- in that priority on ties) with the
    smallest cumulative cost until cell (0, 0) is reached. Cell (0, 0) itself
    is not part of the returned path.

    :param prediction: Tensor whose ``shape[1]`` gives the number of rows.
    :param target: Tensor whose ``shape[1]`` gives the number of columns.
    :param path_matrix: Cumulative costs indexed as ``[row][col]``.
    :return list: ``(row, col)`` pairs ordered from the start of the sequences
        to the bottom-right cell.
    """
    row = prediction.shape[1] - 1
    col = target.shape[1] - 1
    visited = []
    while row > 0 or col > 0:
        visited.append((row, col))
        candidates = []
        if row > 0 and col > 0:
            candidates.append((row - 1, col - 1))
        if row > 0:
            candidates.append((row - 1, col))
        if col > 0:
            candidates.append((row, col - 1))
        lowest_cost = float("inf")
        # When no candidate is strictly below infinity the last one is taken.
        chosen = candidates[-1]
        for candidate in candidates:
            cost = path_matrix[candidate[0]][candidate[1]]
            if cost < lowest_cost:
                lowest_cost = cost
                chosen = candidate
        row, col = chosen
    visited.reverse()
    return visited
def display_matrix_result(path_matrix, matching, prediction, target, processor=None):
    """
    Draw the alignment-cost matrix together with the selected path.

    The figure is returned instead of being shown, so that the caller can
    embed it (the original documentation mentions Gradio).

    :param path_matrix: 2-D cost matrix; it is transposed for display, so its
        first axis runs along x and its second axis along y.
    :param matching: Sequence of (x index, y index) pairs; drawn as a red
        path when it is not empty.
    :param prediction: Scores whose argmax over the last axis (entry 0 of the
        first axis) is decoded into the x tick labels.
    :param target: Scores decoded the same way into the y tick labels.
    :param processor: Object exposing ``decode``; when None it is taken from
        ``common.get_model()``.
    :return: The Matplotlib figure.
    """

    def _visible_ticks(labels):
        """Positions and texts of the labels that are neither empty nor "[PAD]"."""
        kept = [
            (position, text)
            for position, text in enumerate(labels)
            if text not in ("", "[PAD]")
        ]
        return [position for position, _ in kept], [text for _, text in kept]

    fig, axis = plt.subplots(figsize=(12, 8))
    if processor is None:
        _model, processor = common.get_model()

    # Cost heat-map and its colour bar.
    heatmap = axis.matshow(path_matrix.T, aspect="auto", cmap='Blues')
    colorbar = plt.colorbar(heatmap, ax=axis)
    colorbar.set_label('Alignment Cost', rotation=270, labelpad=20, fontsize=11)

    # Axis names and title.
    axis.set_xlabel('Predicted Phoneme Sequence', fontsize=12)
    axis.set_ylabel('Target Phoneme Sequence', fontsize=12)
    axis.set_title(
        'Phoneme Alignment Matrix\n(Blue = Lower Cost, Red Line = Optimal Path)',
        fontsize=14, pad=20
    )

    # Decode the strongest token of every frame into a tick label.
    predicted_labels = tuple(
        processor.decode(token) for token in torch.argmax(prediction, -1)[0]
    )
    target_labels = tuple(
        processor.decode(token) for token in torch.argmax(target, -1)[0]
    )

    # Only label the frames that decode to a visible phoneme.
    x_positions, x_texts = _visible_ticks(predicted_labels)
    if x_positions:
        axis.set_xticks(x_positions)
        axis.set_xticklabels(x_texts, rotation=45, ha='right', fontsize=10)
    y_positions, y_texts = _visible_ticks(target_labels)
    if y_positions:
        axis.set_yticks(y_positions)
        axis.set_yticklabels(y_texts, fontsize=10)

    axis.grid(which="major", color="gray", alpha=0.2, linestyle="-")

    if matching:
        path_style = dict(
            color="red",
            linewidth=3,
            marker='o',
            markersize=4,
            markerfacecolor='white',
            markeredgecolor='red',
            markeredgewidth=2,
            label="Optimal Alignment Path",
            alpha=0.9
        )
        axis.plot(
            [step[0] for step in matching],
            [step[1] for step in matching],
            **path_style
        )
        # The legend only has an entry when the path has been drawn.
        axis.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0), fontsize=11)

    # Reading hint in the top-left corner of the axes.
    axis.text(
        0.02, 0.98, 'Lower values indicate\nbetter alignment',
        transform=axis.transAxes, fontsize=9, va='top', ha='left',
        bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
    )
    plt.tight_layout()
    return fig
def bellman_matching(prediction, target, insertion_cost=1.3, deletion_cost=3, metric=l2_logit_norm):
    """
    Match two sequences with Bellman's algorithm.

    Both sequences are prefixed with an all-zero frame before the cost matrix
    is built, so that an alignment may start on any frame; the pairs touching
    that extra frame are dropped from the result and the remaining indices are
    shifted back to the original (unpadded) numbering. The returned pairs stop
    at the first one that reaches the last target frame.

    :param prediction: Actual prediction, indexed as (batch, frame, token);
        only batch entry 0 is aligned.
    :param target: Target values, indexed as (batch, frame, token).
    :param float insertion_cost: Forwarded to ``compute_path_matrix``.
    :param float deletion_cost: Forwarded to ``compute_path_matrix``.
    :param Callable metric: The metric to use to compare two frames.
    :return tuple(list, float): Alignment as ``(prediction index, target
        index)`` pairs, and the total cost found in the last cell of the cost
        matrix.
    """
    # Add padding: start matching on letters (do not penalize kids starting
    # with insertions or long audio).
    padded_target = torch.zeros((target.shape[0], target.shape[1] + 1, target.shape[2]))
    padded_target[0, 1:] = target
    padded_prediction = torch.zeros((prediction.shape[0], prediction.shape[1] + 1, prediction.shape[2]))
    padded_prediction[0, 1:] = prediction
    path_matrix = compute_path_matrix(
        padded_prediction, padded_target, metric,
        insertion_cost,
        deletion_cost
    )
    padded_matching = solve_path(padded_prediction, padded_target, path_matrix)

    last_target_index = padded_target.shape[1] - 1
    short_matching = []
    for pred_index, target_index in padded_matching:
        # Pairs involving the padding frame are not part of the real alignment.
        if pred_index == 0 or target_index == 0:
            continue
        short_matching.append((pred_index - 1, target_index - 1))
        if target_index == last_target_index:
            break

    # The path always ends on the bottom-right cell, so read the total cost
    # there directly; indexing with ``padded_matching[-1]`` raised IndexError
    # when both sequences were empty and the path had no step at all.
    score = path_matrix[-1, -1]
    return short_matching, score.item()
def score_correct(matching, prediction, target, threshold):
    """
    Count the number of correct phonemes in the target.

    Consecutive pairs of the alignment are compared: a step of ``[0, 1]``
    counts as a deletion, a step of ``[1, 0]`` as an insertion, and any other
    step counts as a substitution when the matched probability falls below
    ``threshold``. The result is the target length minus all errors, floored
    at 0.

    NOTE(review): the first pair has no predecessor, so it is never checked
    for a substitution.

    :param numpy.ndarray matching: Alignment as rows of
        ``(prediction index, target index)``.
    :param prediction: Probabilities indexed as (batch, frame, token).
    :param target: Target indexed as (batch, frame, token).
    :param float threshold: Minimum matched probability for a correct phoneme.
    :return int: Number of target phonemes considered correct.
    """
    # Nothing was aligned (e.g. no usable prediction frame): no phoneme can be
    # correct. Without this guard the error count stayed at 0 and the full
    # target length was returned.
    if len(matching) == 0:
        return 0

    insertions = deletions = substitutions = 0
    for previous, current in zip(matching[:-1], matching[1:]):
        step = current - previous
        if np.all(step == [0, 1]):
            # Target advanced alone: a phoneme is missing in the prediction.
            deletions += 1
        elif np.all(step == [1, 0]):
            # Prediction advanced alone: an extra frame was produced.
            insertions += 1
        else:
            # Match probability, 1 == good match
            match_value = 1 - argmax_selection(
                prediction[0, current[0]],
                target[0, current[1]]
            )
            if match_value < threshold:
                substitutions += 1
    return max(0, target.shape[1] - insertions - deletions - substitutions)
def score_phoneme_deletion(matching, prediction, target, threshold):
    """
    Grade an alignment for the phoneme-deletion scoring as 0, 1 or 2.

    Errors are counted like in ``score_correct``: a step of ``[0, 1]`` between
    consecutive pairs is a deletion, ``[1, 0]`` an insertion, and any other
    step is a substitution when the matched probability is below
    ``threshold``. In addition, when the first pair contains a 0, every pair
    whose prediction index is 0 must NOT match its target: a matched
    probability above ``threshold`` gives 0 immediately.

    :param numpy.ndarray matching: Alignment as rows of
        ``(prediction index, target index)``.
    :param prediction: Probabilities indexed as (batch, frame, token).
    :param target: Target indexed as (batch, frame, token).
    :param float threshold: Probability threshold for a match.
    :return int: 2 when no error was counted, 1 for exactly one error,
        0 otherwise (including an empty alignment).
    """
    # Nothing was aligned: give the lowest grade instead of failing on
    # ``matching[0]`` below.
    if len(matching) == 0:
        return 0

    insertions = deletions = substitutions = 0
    for previous, current in zip(matching[:-1], matching[1:]):
        step = current - previous
        if np.all(step == [0, 1]):
            # Target advanced alone: a phoneme is missing in the prediction.
            deletions += 1
        elif np.all(step == [1, 0]):
            # Prediction advanced alone: an extra frame was produced.
            insertions += 1
        else:
            # Match probability, 1 == good match
            match_value = 1 - argmax_selection(
                prediction[0, current[0]],
                target[0, current[1]]
            )
            if match_value < threshold:
                substitutions += 1

    # First phoneme should NOT match
    if 0 in matching[0]:
        indices = np.argwhere(matching[:, 0] == 0).flatten()
        for i in indices:
            match_value = 1 - argmax_selection(
                prediction[0, matching[i, 0]],
                target[0, matching[i, 1]]
            )
            if match_value > threshold:
                return 0

    errors = insertions + deletions + substitutions
    if errors == 0:
        return 2
    if errors == 1:
        return 1
    return 0
def remove_pad_tokens(prediction, pad_token_id, temperature):
    """
    Drop the frames predicted as padding, to decrease temporal effects.

    The prediction is divided by ``temperature`` and softmaxed over its last
    axis; every frame whose most likely token is ``pad_token_id`` is then
    removed.

    :param prediction: Predicted logits (anything ``torch.as_tensor`` accepts).
    :param int pad_token_id: ID of the pad token.
    :param float temperature: Temperature to pass to the SoftMax.
    :return torch.Tensor: Probabilities of size (1, kept frames, tokens) where
        no frame has the pad token as its argmax.
    """
    scaled = torch.as_tensor(prediction) / temperature
    probabilities = torch.softmax(scaled, dim=-1)
    is_speech = probabilities.argmax(-1) != pad_token_id
    kept = probabilities[is_speech]
    # Restore the leading batch axis removed by the boolean selection.
    n_frames, n_tokens = kept.shape[0], kept.shape[1]
    return kept.reshape((1, n_frames, n_tokens))
def get_alignment_score(
        prediction,
        target,
        weights,
        pad_token_id=58,
        scoring=common.Scoring.NUMBER_CORRECT
):
    """
    Score a prediction against a target after aligning the two sequences.

    The pad frames are first removed from the softmaxed prediction, the
    remaining frames are aligned with the target using ``l2_logit_norm`` as
    frame metric, and the alignment is handed to the scorer selected by
    ``scoring``.

    :param prediction: The output of the model, without activation function.
    :param target: The sequence to match, indexed as (batch, frame, token).
    :param weights: Indexable with four entries: ``[0]`` insertion cost,
        ``[1]`` deletion cost, ``[2]`` match threshold given to the scorer,
        ``[3]`` softmax temperature.
    :param int pad_token_id: ID of the token whose frames are ignored.
    :param common.Scoring scoring: Type of scoring to use.
    :return int: Result of ``score_correct`` for ``NUMBER_CORRECT`` or of
        ``score_phoneme_deletion`` (0, 1 or 2) for ``PHONEME_DELETION``.
    :raises NotImplementedError: For any other scoring value.
    """
    temperature = weights[3]
    collapsed_prediction = remove_pad_tokens(prediction, pad_token_id, temperature)
    # The alignment cost is not needed by either scorer.
    matching, _alignment_cost = bellman_matching(
        collapsed_prediction,
        target,
        insertion_cost=weights[0],
        deletion_cost=weights[1],
        metric=l2_logit_norm
    )
    np_matching = np.array(matching)

    if scoring is common.Scoring.NUMBER_CORRECT:
        scorer = score_correct
    elif scoring is common.Scoring.PHONEME_DELETION:
        scorer = score_phoneme_deletion
    else:
        raise NotImplementedError("Unknown scoring method.")
    return scorer(np_matching, collapsed_prediction, target, weights[2])