"""Just a demo code to use the model."""
import torch
import torchaudio
import transformers
import wavlm_phoneme_fr_it
# Prepare the input data.
# The feature extractor below expects raw 16 kHz audio.
SAMPLING_RATE = 16_000
audio_files = [
    {
        "path": "audio-samples/tsenkher-fr.wav",
        "language": "fr",
        "text": "Sa capitale est Tsenkher."
    },
    {
        "path": "audio-samples/italiens-fr.wav",
        "language": "fr",
        "text": "Les Italiens ont été les premiers à réagir."
    },
    {
        "path": "audio-samples/entrato-it.wav",
        "language": "it",
        "text": "Ma nessuno può esservi entrato!"
    }
]
audio_arrays = []
for audio in audio_files:
    # torchaudio.load returns a (channels, samples) tensor and the file's
    # native sample rate.
    audio_array, frequency = torchaudio.load(audio["path"])
    if frequency != SAMPLING_RATE:
        # Fail fast instead of feeding the model mis-sampled audio.
        # (Fixed message: previously read "it it {frequency}".)
        raise ValueError(
            f"Input audio frequency should be {SAMPLING_RATE} Hz, it is {frequency} Hz."
        )
    # Keep only the first channel, as a 1-D numpy array.
    audio_arrays.append(audio_array[0].numpy())
# Build the CTC processor: a feature extractor for the raw waveform paired
# with a character-level CTC tokenizer read from the local vocabulary file.
processor = transformers.Wav2Vec2Processor(
    feature_extractor=transformers.AutoFeatureExtractor.from_pretrained(
        "microsoft/wavlm-base-plus"
    ),
    tokenizer=transformers.Wav2Vec2CTCTokenizer(
        "./vocab.json",
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token="|",
    ),
)
# Pad the batch to a common length and return PyTorch tensors.
inputs = processor(
    audio_arrays,
    sampling_rate=SAMPLING_RATE,
    padding=True,
    return_tensors="pt",
)
inputs["language"] = [row["language"] for row in audio_files]  # "fr" or "it"
# Fetch the model together with its pretrained weights.
model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(
    "hugofara/wavlm-base-plus-phonemizer-fr-it"
)
# Forward pass without gradient tracking, then a simple greedy (argmax)
# decoding over the CTC logits — good enough for a demonstration.
with torch.no_grad():
    outputs = model(**inputs)
best_token_ids = outputs.logits.argmax(dim=-1)
predictions = processor.batch_decode(best_token_ids)
# Render the results as a three-column table.
column_length = 34
print(" | ".join(
    title.center(column_length)
    for title in ("Input file", "Predicted phonemes", "Original text")
))
for entry, phonemes in zip(audio_files, predictions):
    row = (
        entry["path"].center(column_length),
        "".join(phonemes).center(column_length),
        entry["text"],  # reference transcript, printed as-is
    )
    print(" | ".join(row))