|
|
"""Just a demo code to use the model.""" |
|
|
import torch |
|
|
import torchaudio |
|
|
import transformers |
|
|
|
|
|
import wavlm_phoneme_fr_it |
|
|
|
|
|
|
|
|
|
|
|
# Sample rate (Hz) the model was trained on; input audio must match it exactly.
SAMPLING_RATE = 16_000
|
|
|
|
|
# Demo inputs: each entry pairs a WAV file with its spoken language code
# ("fr"/"it", consumed by the model below) and the reference transcript
# (display only — it is never fed to the model).
audio_files = [
    {
        "path": "audio-samples/tsenkher-fr.wav",
        "language": "fr",
        "text": "Sa capitale est Tsenkher."
    },
    {
        "path": "audio-samples/italiens-fr.wav",
        "language": "fr",
        "text": "Les Italiens ont été les premiers à réagir."
    },
    {
        "path": "audio-samples/entrato-it.wav",
        "language": "it",
        "text": "Ma nessuno può esservi entrato!"
    }
]
|
|
# Load every sample as a raw waveform. The model expects 16 kHz audio, so files
# at any other rate are rejected outright rather than resampled silently.
audio_arrays = []
for audio in audio_files:
    audio_array, frequency = torchaudio.load(audio["path"])
    if frequency != SAMPLING_RATE:
        # Fixed typo in the error message: "it it" -> "it is".
        raise ValueError(
            f"Input audio frequency should be {SAMPLING_RATE} Hz, it is {frequency} Hz."
        )
    # Keep only the first channel and hand the processor a NumPy array.
    audio_arrays.append(audio_array[0].numpy())
|
|
|
|
|
|
|
|
|
|
|
# Build the audio front-end from the base WavLM checkpoint (the fine-tuned
# model reuses its feature-extraction config).
feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(
    "microsoft/wavlm-base-plus"
)
# CTC tokenizer over the local phoneme vocabulary; "|" separates words.
tokenizer = transformers.Wav2Vec2CTCTokenizer(
    "./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
)
# Single processor wrapping both: preprocesses audio in, decodes label ids out.
processor = transformers.Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
|
|
|
|
# Batch the waveforms: pad to a common length and return PyTorch tensors.
inputs = processor(
    audio_arrays,
    sampling_rate=SAMPLING_RATE,
    padding=True,
    return_tensors="pt",
)
# Per-sample language codes, forwarded to the model alongside the audio
# (the fine-tuned model accepts a "language" keyword — see the model class).
inputs["language"] = [row["language"] for row in audio_files]
|
|
|
|
|
|
|
|
|
|
|
# Load the fine-tuned French/Italian phonemizer from the Hub.
model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(
    "hugofara/wavlm-base-plus-phonemizer-fr-it"
)

# Inference only: disable gradient tracking for speed and memory.
with torch.no_grad():
    logits = model(**inputs).logits

# Greedy CTC decoding: pick the most likely label per frame, then let the
# processor collapse repeats/blanks into phoneme strings.
label_ids = torch.argmax(logits, -1)
predictions = processor.batch_decode(label_ids)
|
|
|
|
|
|
|
|
# Render the results as a three-column table: input path, predicted
# phonemes, and the reference transcript (last column left unpadded).
column_length = 34
header_cells = (
    "Input file".center(column_length),
    "Predicted phonemes".center(column_length),
    "Original text".center(column_length),
)
print(" | ".join(header_cells))
for entry, phonemes in zip(audio_files, predictions):
    row_cells = (
        entry["path"].center(column_length),
        "".join(phonemes).center(column_length),
        entry["text"],
    )
    print(" | ".join(row_cells))
|
|
|