# wavlm-base-plus-phonemizer-fr-it / wavlm_phoneme_fr_it.py
# Hugo Farajallah
# feat(model): new model with transformers and 10.1% PER.
# 40f54ba
import torch
import transformers
# Offset into the tuple returned by the WavLM backbone: entries from this index
# onwards are appended after the logits when forward() returns a plain tuple.
_HIDDEN_STATES_START_POSITION = 2
class WavLMPhonemeFrIt(transformers.WavLMForCTC):
    """
    PhonemeRecognizer: WavLM + language-aware Linear layer for speech recognition.

    One extra scalar feature identifying the language (0 for French, 1 for
    Italian) is appended to every frame coming out of the WavLM backbone before
    the linear head, so the head natively separates French and Italian.

    For a more professional implementation, view
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py
    """

    def __init__(self, config):
        """
        Build the CTC model, then replace its head by one taking an extra input feature.

        :param config: Model config.
        """
        super().__init__(config)
        output_hidden_size = (
            config.output_hidden_size if getattr(config, "add_adapter", False) else config.hidden_size
        )
        # Replace head and add multilingualism: the "+ 1" input is the language feature.
        self.lm_head = torch.nn.Linear(output_hidden_size + 1, config.vocab_size)

    def language_classifer(self, language):
        """
        Return a float identifying each known language.

        "fr" has a value of 0 and "it" a value of 1. Any other code is mapped
        from its letters: for lowercase two-letter codes the result lies
        between 0 ("aa") and 1 ("zz"). The mapping is not injective (it is a sum
        over letters), so distinct codes can share a value.

        :param str language: Language to identify, should be two lowercase letters.
        :return float: Identifier, between 0 and 1 for two lowercase letters.
        """
        if language == "fr":
            return 0.0
        if language == "it":
            return 1.0
        # Each letter is scaled to [0, 1] and offset by its position, so that
        # "aa" gives 0 + 1 = 1 and "zz" gives 1 + 2 = 3.
        letter_span = ord("z") - ord("a")
        total = sum(
            (ord(letter) - ord("a")) / letter_span + position
            for position, letter in enumerate(language)
        )
        # Transform [1, 3] to [0, 1].
        return (total - 1) / 2

    def add_language_to_hidden(self, input_values, language):
        """
        Append the language identifier as one extra feature on the last axis.

        :param torch.Tensor input_values: 3-D tensor indexed as (batch, frames, features).
        :param language: ``None`` (encoded as 0, the same value as "fr"), a single
            language code applied to the whole batch, or an iterable holding one
            code per batch item.
        :return torch.Tensor: Tensor of shape (batch, frames, features + 1) with
            the same dtype and device as ``input_values``.
        """
        batch_size, n_frames = input_values.shape[0], input_values.shape[1]
        if language is None:
            codes = [0.0]
        elif isinstance(language, str):
            codes = [self.language_classifer(language)]
        else:
            codes = [self.language_classifer(lang) for lang in language]
        # Build the column directly on the right device/dtype instead of copying
        # a default-device tensor into a pre-allocated buffer.
        lang_column = torch.tensor(codes, dtype=input_values.dtype, device=input_values.device)
        # A single code broadcasts over the batch; otherwise expand() requires
        # exactly one code per batch item and raises on a mismatch.
        lang_column = lang_column.view(-1, 1, 1).expand(batch_size, n_frames, 1)
        return torch.cat((input_values, lang_column), dim=-1)

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor = None,
        language=None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        labels: torch.Tensor = None,
    ):
        """
        Run the WavLM backbone, append the language feature and project to phoneme logits.

        When ``labels`` is given, a CTC loss is computed as well, using
        ``config.pad_token_id`` as the blank token.

        Adapted from
        https://github.com/huggingface/transformers/blob/6ba8a1ff4550b4450a22a0b0d907312955ce0fd5/src/transformers/models/wavlm/modeling_wavlm.py#L1196

        :param input_values: Input passed to the WavLM backbone.
        :param attention_mask: Optional mask forwarded to the backbone; also used
            to derive the input lengths for the CTC loss (all ones when omitted).
        :param language: See :meth:`add_language_to_hidden`.
        :param output_attentions: Forwarded to the backbone.
        :param output_hidden_states: Forwarded to the backbone.
        :param return_dict: Return a ``CausalLMOutput`` when true, a tuple
            otherwise; defaults to ``config.use_return_dict``.
        :param labels: Optional targets; negative entries are treated as padding
            and excluded from the CTC targets.
        :return: ``CausalLMOutput``, or a tuple ``(loss?, logits, *extra)``.
        :raises ValueError: If a label is greater than or equal to ``config.vocab_size``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        hidden_with_lang = self.add_language_to_hidden(hidden_states, language)
        logits = self.lm_head(hidden_with_lang)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = torch.nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return transformers.modeling_outputs.CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def freeze_feature_encoder_only(self):
        """
        Make every backbone parameter trainable, then freeze the feature encoder again.
        """
        # Unfreeze base model
        for param in self.wavlm.parameters():
            param.requires_grad = True
        # Now freeze the first layer
        self.freeze_feature_encoder()
def freeze_layer(layer, freeze=True):
    """
    Freeze or unfreeze every parameter of a layer.

    The resulting state is also recorded on the layer itself through its
    ``_requires_grad`` attribute.

    :param layer: Object exposing ``parameters()``.
    :param bool freeze: Disable gradients when true, enable them otherwise.
    """
    trainable = not freeze
    for weight in layer.parameters():
        weight.requires_grad_(trainable)
    setattr(layer, "_requires_grad", trainable)
def get_wavlm_phoneme_fr_it(tokenizer, freeze_hidden_layers=False):
    """
    Instantiate the phoneme recognizer from the "microsoft/wavlm-base-plus" checkpoint.

    :param tokenizer: Tokenizer supplying ``pad_token_id`` and, through ``len()``,
        the vocabulary size.
    :param bool freeze_hidden_layers: When truthy, ``freeze_base_model()`` is
        called on the loaded model.
    :return WavLMPhonemeFrIt: The loaded model.
    """
    overrides = {
        "ctc_loss_reduction": "mean",
        "pad_token_id": tokenizer.pad_token_id,
        "vocab_size": len(tokenizer),
    }
    recognizer = WavLMPhonemeFrIt.from_pretrained("microsoft/wavlm-base-plus", **overrides)
    if not freeze_hidden_layers:
        return recognizer
    recognizer.freeze_base_model()
    return recognizer