import torch
import transformers

_HIDDEN_STATES_START_POSITION = 2
|
|
class WavLMPhonemeFrIt(transformers.WavLMForCTC):
    """
    PhonemeRecognizer: WavLM + linear LM head for phoneme recognition.

    It natively separates French and Italian.

    For a reference implementation, see
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py
    """

    def __init__(self, config):
        """
        Build the combined model: the WavLM backbone plus a linear LM head
        mapping hidden features (and a language channel) to phonemes.

        :param config: Model config.
        """
        super().__init__(config)
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )

        # The extra input feature carries the language identifier channel.
        self.lm_head = torch.nn.Linear(output_hidden_size + 1, config.vocab_size)
|
    def language_classifier(self, language):
        """
        Return a float identifying each known language.

        "fr" has a value of 0, "it" a value of 1.
        Other languages map to the mean normalized alphabet position of
        their letters.

        :param str language: Language to identify; should be two letters.
        :return float: Identifier between 0 and 1.
        """
        if language == "fr":
            return 0
        if language == "it":
            return 1

        codes = (
            (ord(letter) - ord("a")) / (ord("z") - ord("a")) + i
            for i, letter in enumerate(language)
        )
        return (sum(codes) - 1) / 2
|
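    # Worked example of the generic branch above: "de" gives
    # ((ord("d") - ord("a")) / 25 + (ord("e") - ord("a")) / 25) / 2
    # = (0.12 + 0.16) / 2 = 0.14.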
|
    def add_language_to_hidden(self, input_values, language):
        """
        Append a language-identifier channel to every frame of the hidden states.

        :param torch.Tensor input_values: Hidden states of shape (batch, frames, features).
        :param language: None, a language code, or one language code per batch item.
        :return torch.Tensor: Tensor of shape (batch, frames, features + 1).
        """
        input_batch = torch.empty(
            (input_values.shape[0], input_values.shape[1], input_values.shape[2] + 1),
            dtype=input_values.dtype,
            device=input_values.device,
        )

        input_batch[:, :, :-1] = input_values
        if language is None:
            input_batch[:, :, -1] = 0
        elif isinstance(language, str):
            # One language for the whole batch.
            input_batch[:, :, -1] = self.language_classifier(language)
        else:
            # One language per batch item, broadcast over all frames.
            lang_val = torch.tensor(
                [[self.language_classifier(lang)] for lang in language],
                dtype=input_batch.dtype,
                device=input_batch.device,
            ).repeat((1, input_batch.shape[1]))
            input_batch[:, :, -1] = lang_val
        return input_batch
|
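    # Shape sketch: with a batch of 2 utterances, 49 frames and hidden size 768,
    # add_language_to_hidden(hidden, ["fr", "it"]) returns a (2, 49, 769) tensor
    # whose last channel is all 0 for the first item and all 1 for the second.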
|
    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor = None,
        language=None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        labels: torch.Tensor = None,
    ):
        """
        Compute frame-level phoneme logits and, if labels are given, the CTC loss.

        Adapted from
        https://github.com/huggingface/transformers/blob/6ba8a1ff4550b4450a22a0b0d907312955ce0fd5/src/transformers/models/wavlm/modeling_wavlm.py#L1196
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        # Concatenate the language channel before projecting to phoneme logits.
        hidden_with_lang = self.add_language_to_hidden(hidden_states, language)
        logits = self.lm_head(hidden_with_lang)

        loss = None
        if labels is not None:
            # Retrieve the length in feature frames of each audio input.
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # Label positions padded with -100 are ignored.
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16, so compute log-probs in fp32.
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = torch.nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return transformers.modeling_outputs.CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
|
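    # Decoding note: the returned logits have shape (batch, frames, vocab_size).
    # A simple greedy CTC decode is argmax over the last dimension, then
    # collapsing repeats and dropping the blank (pad) token, e.g. via
    # tokenizer.batch_decode(logits.argmax(dim=-1)).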
|
    def freeze_feature_encoder_only(self):
        """Unfreeze the whole backbone, then freeze only the feature encoder."""
        for param in self.wavlm.parameters():
            param.requires_grad = True

        self.freeze_feature_encoder()
|
|
def freeze_layer(layer, freeze=True):
    """Freeze (or unfreeze, with freeze=False) every parameter of a layer."""
    for param in layer.parameters():
        param.requires_grad = not freeze
    # Mirror the flag that HF feature encoders keep on the module itself.
    layer._requires_grad = not freeze
|
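# Example: freeze only part of the backbone, e.g. the first six transformer
# blocks of the encoder (model.wavlm.encoder.layers is a ModuleList):
#     for block in model.wavlm.encoder.layers[:6]:
#         freeze_layer(block)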
|
def get_wavlm_phoneme_fr_it(tokenizer, freeze_hidden_layers=False):
    """Load the pretrained backbone and wrap it as a French/Italian phoneme recognizer."""
    model = WavLMPhonemeFrIt.from_pretrained(
        "microsoft/wavlm-base-plus",
        ctc_loss_reduction="mean",
        pad_token_id=tokenizer.pad_token_id,
        vocab_size=len(tokenizer),
    )
    if freeze_hidden_layers:
        model.freeze_base_model()
    return model
|
|
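# Minimal usage sketch (hedged): "vocab.json" is a hypothetical phoneme
# vocabulary file, and a random waveform stands in for one second of real
# 16 kHz speech.
if __name__ == "__main__":
    tokenizer = transformers.Wav2Vec2CTCTokenizer("vocab.json")  # hypothetical vocab file
    model = get_wavlm_phoneme_fr_it(tokenizer, freeze_hidden_layers=True)
    model.eval()  # disable dropout for inference
    waveform = torch.randn(1, 16000)  # fake audio: (batch=1, samples)
    with torch.no_grad():
        logits = model(waveform, language="fr").logits  # (1, frames, vocab_size)
    print(tokenizer.batch_decode(logits.argmax(dim=-1)))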