File size: 2,195 Bytes
f6be2ea
 
 
 
 
40f54ba
f6be2ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40f54ba
 
 
 
 
 
 
 
 
 
 
f6be2ea
 
 
 
 
 
 
40f54ba
 
 
 
 
f6be2ea
 
40f54ba
 
 
ea4351a
 
40f54ba
f6be2ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Just a demo code to use the model."""
import torch
import torchaudio
import transformers

import wavlm_phoneme_fr_it


# Prepare the input data
# The feature extractor is called with this rate; files at any other rate are
# rejected below rather than resampled.
SAMPLING_RATE = 16_000

# Each sample: path to a WAV file, its language code ("fr" or "it") passed to
# the model, and the reference sentence printed next to the prediction.
audio_files = [
    {
        "path": "audio-samples/tsenkher-fr.wav",
        "language": "fr",
        "text": "Sa capitale est Tsenkher.",
    },
    {
        "path": "audio-samples/italiens-fr.wav",
        "language": "fr",
        "text": "Les Italiens ont été les premiers à réagir.",
    },
    {
        "path": "audio-samples/entrato-it.wav",
        "language": "it",
        "text": "Ma nessuno può esservi entrato!",
    },
]

# One 1-D NumPy waveform per entry of audio_files, in the same order.
audio_arrays = []
for audio in audio_files:
    # torchaudio.load returns (waveform, sample_rate); waveform is
    # (channels, frames).
    audio_array, frequency = torchaudio.load(audio["path"])
    if frequency != SAMPLING_RATE:
        raise ValueError(
            f"{audio['path']}: input audio frequency should be "
            f"{SAMPLING_RATE} Hz, it is {frequency} Hz."
        )
    # Only the first channel is kept; extra channels of a stereo file are
    # dropped.
    audio_arrays.append(audio_array[0].numpy())


# Build the CTC processor from two parts:
# - the audio feature extractor configuration of the base WavLM checkpoint
# - a CTC tokenizer whose vocabulary is read from the local vocab.json
feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(
    "microsoft/wavlm-base-plus"
)
tokenizer = transformers.Wav2Vec2CTCTokenizer(
    "./vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
processor = transformers.Wav2Vec2Processor(
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
)

# Batch every waveform into padded PyTorch tensors.
inputs = processor(
    audio_arrays, sampling_rate=SAMPLING_RATE, padding=True, return_tensors="pt"
)
# The model is called with **inputs, so it also receives each sample's
# language code ("fr" or "it"), in the same order as audio_files.
inputs["language"] = [entry["language"] for entry in audio_files]


# Fetch the fine-tuned French/Italian phonemizer weights
model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(
    "hugofara/wavlm-base-plus-phonemizer-fr-it"
)

# Inference only: turn off autograd tracking for the forward pass
with torch.no_grad():
    model_output = model(**inputs)
    logits = model_output.logits

# Greedy decoding: take the highest-scoring id along the last axis
# (presumably the vocabulary axis -- confirm against the model head),
# then let the CTC tokenizer turn the id sequences into strings.
label_ids = logits.argmax(dim=-1)
predictions = processor.batch_decode(label_ids)


# Print a table: input file | predicted phonemes | reference sentence.
# The first two columns are centred in a fixed width; the last column is
# printed unpadded (left-aligned), and its header is treated the same way so
# the title lines up with the sentences below it.
column_length = 34
print(
    "Input file".center(column_length),
    "Predicted phonemes".center(column_length),
    "Original text",
    sep=" | ",
)
for sample, prediction in zip(audio_files, predictions):
    # batch_decode already yields one decoded string per sample, so the
    # prediction can be padded directly without joining it first.
    print(
        sample["path"].center(column_length),
        prediction.center(column_length),
        sample["text"],
        sep=" | ",
    )