Source code for montreal_forced_aligner.dictionary.remapper

"""Classes for remapping a dictionary from one phone set to another"""
from __future__ import annotations

import itertools
import logging
import os
from pathlib import Path

import yaml

from montreal_forced_aligner.abc import TopLevelMfaWorker
from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionaryMixin
from montreal_forced_aligner.exceptions import RemapAcousticMismatchError
from montreal_forced_aligner.helper import format_correction, format_probability, mfa_open
from montreal_forced_aligner.models import AcousticModel

logger = logging.getLogger("mfa")

__all__ = ["DictionaryRemapper"]


[docs] class DictionaryRemapper(MultispeakerDictionaryMixin, TopLevelMfaWorker): def __init__( self, acoustic_model_path: Path, phone_mapping_path: Path, **kwargs, ): self._data_source = kwargs["dictionary_path"].stem super().__init__(**kwargs) self.acoustic_model = AcousticModel(acoustic_model_path) self.phone_mapping_path = phone_mapping_path self.phone_remapping = {} @property def data_source_identifier(self) -> str: """Dictionary name""" return self._data_source @property def data_directory(self) -> Path: """Data directory for trainer""" return self.working_directory
[docs] def setup(self) -> None: """Setup for dictionary remapping""" super().setup() self.load_mapping() self.validate_mapping() if self.initialized: return self.dictionary_setup() os.makedirs(self.phones_dir, exist_ok=True) self.initialized = True
def load_mapping(self): with mfa_open(self.phone_mapping_path, "r") as f: self.phone_remapping = yaml.load(f, Loader=yaml.Loader) for key, values in self.phone_remapping.items(): if not isinstance(values, list): self.phone_remapping[key] = [values] def validate_mapping(self): unknown_phones = set() for key, values in self.phone_remapping.items(): for value in values: for p in value.split(): if p not in self.acoustic_model.meta["phones"]: unknown_phones.add(p) if unknown_phones: raise RemapAcousticMismatchError(unknown_phones, self.phone_mapping_path) def remap(self, output_dictionary_path: Path): self.setup() new_dictionary = {} skip_count = 0 extra_prob_keys = [ "silence_after_probability", "silence_before_correction", "non_silence_before_correction", ] for data in self.words_for_export(probability=True): phones = data["pronunciation"] w = data["word"] pron = phones.split() skip = False new_pron = [] for p in pron: if p not in self.phone_remapping: if p in self.acoustic_model.meta["phones"]: new_p = p else: skip = True else: new_p = self.phone_remapping[p] if skip: break if not isinstance(new_p, list): new_p = [new_p] new_pron.append(new_p) if skip: logger.debug(f"Skipping {w}: {' '.join(pron)}") skip_count += 1 continue if w not in new_dictionary: new_dictionary[w] = {} pron_combinations = list(itertools.product(*new_pron)) for new_pron in pron_combinations: pron_string = " ".join(new_pron) if pron_string not in new_dictionary[w]: new_dictionary[w][pron_string] = { "count": 1, "probability": data["probability"], "silence_after_probability": data["silence_after_probability"], "silence_before_correction": data["silence_before_correction"], "non_silence_before_correction": data["non_silence_before_correction"], } else: new_dictionary[w][pron_string]["count"] += 1 if data["probability"] is not None: if new_dictionary[w][pron_string]["probability"] is None: new_dictionary[w][pron_string]["probability"] = data["probability"] else: new_dictionary[w][pron_string]["probability"] = max( data["probability"], new_dictionary[w][pron_string]["probability"] ) for k in extra_prob_keys: if data[k] is not None: if new_dictionary[w][pron_string][k] is None: new_dictionary[w][pron_string][k] = data[k] else: new_dictionary[w][pron_string][k] += data[k] logger.info(f"Skipped {skip_count} pronunciations for having unmapped phones") with mfa_open(output_dictionary_path, "w") as f: for w, prons in sorted(new_dictionary.items(), key=lambda x: x[0]): for pron, data in sorted(prons.items(), key=lambda x: x[0]): probability_string = "" if data["probability"] is not None: probability_string = f"\t{format_probability(data['probability'])}" extra_probs = [ data["silence_after_probability"], data["silence_before_correction"], data["non_silence_before_correction"], ] if all(x is None for x in extra_probs): continue for i, x in enumerate(extra_probs): if x is None: continue mean_value = x / data["count"] if i == 0: mean_value = format_correction(mean_value) else: mean_value = format_correction(mean_value, positive_only=False) probability_string += f"\t{mean_value}" f.write(f"{w}{probability_string}\t{pron}\n") logger.info(f"Wrote remapped dictionary to {output_dictionary_path}")