Hugo Farajallah
fix(README): typo in demo code.
ea4351a
"""Just a demo code to use the model."""
import torch
import torchaudio
import transformers
import wavlm_phoneme_fr_it
# Prepare the input data
# Sampling rate (in Hz) that every input clip must have; it is checked when
# each file is loaded and is also handed to the processor call further down.
SAMPLING_RATE = 16_000

# One entry per sample: the audio file location, the language tag that is
# later attached to the model inputs ("fr" or "it"), and the reference text
# printed next to the prediction.
audio_files = [
    {
        "path": "audio-samples/tsenkher-fr.wav",
        "language": "fr",
        "text": "Sa capitale est Tsenkher."
    },
    {
        "path": "audio-samples/italiens-fr.wav",
        "language": "fr",
        "text": "Les Italiens ont été les premiers à réagir."
    },
    {
        "path": "audio-samples/entrato-it.wav",
        "language": "it",
        "text": "Ma nessuno può esservi entrato!"
    }
]

# Loaded waveforms, kept in the same order as `audio_files`.
audio_arrays = []
for audio in audio_files:
    # torchaudio.load returns the waveform tensor and its sampling rate.
    audio_array, frequency = torchaudio.load(audio["path"])
    if frequency != SAMPLING_RATE:
        # No resampling is attempted here: a mismatched clip is rejected.
        raise ValueError(
            f"Input audio frequency should be {SAMPLING_RATE} Hz, it is {frequency} Hz."
        )
    # Keep only index 0 of the loaded tensor (the first channel with
    # torchaudio's default channel-first layout), converted to NumPy.
    audio_arrays.append(audio_array[0].numpy())
# Load the CTC processor
# Audio side: feature extractor loaded from the "microsoft/wavlm-base-plus"
# identifier.
feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(
    "microsoft/wavlm-base-plus"
)

# Label side: CTC tokenizer built from the local vocabulary file, with the
# special tokens spelled out explicitly.
special_tokens = {
    "unk_token": "[UNK]",
    "pad_token": "[PAD]",
    "word_delimiter_token": "|",
}
tokenizer = transformers.Wav2Vec2CTCTokenizer("./vocab.json", **special_tokens)

# Bundle both into a single processor, used here to encode the audio and
# later to decode the predicted label ids.
processor = transformers.Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)

# Encode all clips as one batch of PyTorch tensors; padding is enabled because
# the clips are processed together.
inputs = processor(
    audio_arrays,
    sampling_rate=SAMPLING_RATE,
    padding=True,
    return_tensors="pt",
)

# The model is called with **inputs, so the per-sample language tag
# ("fr" or "it") travels with the batch as an extra entry.
inputs["language"] = [entry["language"] for entry in audio_files]
# Model with weights
model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(
    "hugofara/wavlm-base-plus-phonemizer-fr-it"
)

# Do inference
# This is a forward-only pass, so gradient tracking is switched off.
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

# Simple ArgMax for demonstration
# Pick the highest-scoring label along the last axis, then let the processor
# turn the label ids back into text.
label_ids = logits.argmax(dim=-1)
predictions = processor.batch_decode(label_ids)
# Print a small three-column report: input file, predicted phonemes, and the
# reference text.
column_length = 34
separator = " | "

# Header row: every title is centred within the column width.
titles = ("Input file", "Predicted phonemes", "Original text")
print(separator.join(title.center(column_length) for title in titles))

# One row per sample. The first two cells are centred; the reference text is
# the last column and is printed as-is.
for entry, phonemes in zip(audio_files, predictions):
    cells = (
        entry["path"].center(column_length),
        "".join(phonemes).center(column_length),
        entry["text"],
    )
    print(separator.join(cells))