Source code for montreal_forced_aligner.g2p.trainer

"""Class definitions for training G2P models"""
from __future__ import annotations

import itertools
import logging
import operator
import os
import queue
import random
import re
import shutil
import subprocess
import threading
import time
import typing
from pathlib import Path
from queue import Queue
from typing import Any, List, NamedTuple, Set

import pynini
import pywrapfst
from pynini import Fst
from tqdm.rich import tqdm

from montreal_forced_aligner import config
from import MetaDict, MfaWorker, TopLevelMfaWorker, TrainerMixin
from import WordType, WorkflowType
from montreal_forced_aligner.db import Pronunciation, Word
from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionaryMixin
from montreal_forced_aligner.exceptions import KaldiProcessingError, PyniniAlignmentError
from montreal_forced_aligner.g2p.generator import PyniniValidator
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.models import G2PModel
from montreal_forced_aligner.utils import thirdparty_binary

Labels = List[Any]

TOKEN_TYPES = ["byte", "utf8"]
INF = float("inf")
RAND_MAX = 32767

__all__ = ["RandomStartWorker", "PyniniTrainer", "PyniniTrainerMixin", "G2PTrainer"]

logger = logging.getLogger("mfa")

class RandomStart(NamedTuple):
    """Parameters for random starts"""

    idx: int
    seed: int
    input_far_path: Path
    output_far_path: Path
    cg_path: Path
    tempdir: Path
    train_opts: List[str]

def _get_far_labels(far_path: typing.Union[Path, str]) -> Set[int]:
    """Extracts label set from acceptors in a FAR.

    Args:
        far_path: :class:`~pathlib.Path` to FAR file.

    Returns:
        A set of integer labels found in the FAR.
    """
    labels: Set[int] = set()
    reader =
    while not reader.done():
        fst = reader.get_fst()
        assert, True) == pywrapfst.ACCEPTOR
        for state in fst.states():
            labels.update(arc.ilabel for arc in fst.arcs(state))
        next(reader)
    assert not reader.error()
    return labels

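# Illustrative usage sketch for _get_far_labels (added for orientation, not part
# of the original module): compile two byte-mode acceptors into a FAR with
# pynini, then collect the arc labels. The file name "demo.far" is hypothetical.
def _example_far_labels(tmpdir: Path) -> Set[int]:
    far_path = tmpdir.joinpath("demo.far")
    far_writer = pynini.Far(str(far_path), mode="w")
    for i, word in enumerate(["cat", "dog"]):
        # FAR keys must be added in ascending order, hence the zero padding.
        far_writer.add(f"{i:08d}", pynini.accep(word))
    far_writer.close()
    # Returns the byte labels of "cat" and "dog", i.e. their code points.
    return _get_far_labels(far_path)
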
class RandomStartWorker(threading.Thread):
    """
    Random start worker
    """

    def __init__(
        self,
        job_name: int,
        job_q: Queue,
        return_queue: Queue,
        log_file: str,
        stopped: threading.Event,
    ):
        super().__init__()
        self.job_name = job_name
        self.job_q = job_q
        self.return_queue = return_queue
        self.log_file = log_file
        self.stopped = stopped
        self.finished = threading.Event()
    def run(self) -> None:
        """Run the random start worker"""
        with mfa_open(self.log_file, "w") as log_file:
            while True:
                try:
                    args = self.job_q.get(timeout=1)
                except queue.Empty:
                    break
                if self.stopped.is_set():
                    continue
                try:
                    start = time.time()
                    # Randomize channel model.
                    rfst_path = args.tempdir.joinpath(f"random-{args.seed:05d}.fst")
                    afst_path = args.tempdir.joinpath(f"aligner-{args.seed:05d}.fst")
                    likelihood_path = afst_path.with_suffix(".like")
                    if not afst_path.exists():
                        cmd = [
                            thirdparty_binary("baumwelchrandomize"),
                            f"--seed={args.seed}",
                            str(args.cg_path),
                            str(rfst_path),
                        ]
                        subprocess.check_call(cmd, stderr=log_file, env=os.environ)
                        random_end = time.time()
                        log_file.write(
                            f"{args.seed} randomization took {random_end - start} seconds\n"
                        )
                        # Train on randomized channel model.
                        likelihood = INF
                        cmd = [
                            thirdparty_binary("baumwelchtrain"),
                            *args.train_opts,
                            str(args.input_far_path),
                            str(args.output_far_path),
                            str(rfst_path),
                            str(afst_path),
                        ]
                        log_file.write(f"{args.seed} train command: {' '.join(cmd)}\n")
                        log_file.flush()
                        with subprocess.Popen(
                            cmd, stderr=subprocess.PIPE, text=True, env=os.environ
                        ) as proc:
                            # Parses stderr to capture the likelihood of each iteration.
                            for line in proc.stderr:  # type: ignore
                                log_file.write(line)
                                log_file.flush()
                                line = line.rstrip()
                                match = re.match(
                                    r"INFO: Iteration \d+.* (-?\d*(\.\d*)?)", line
                                )
                                if not match:
                                    continue
                                likelihood = float(
                                self.return_queue.put(1)
                        with mfa_open(likelihood_path, "w") as f:
                            f.write(str(likelihood))
                        log_file.write(
                            f"{args.seed} training took {time.time() - random_end:.3f} seconds\n"
                        )
                    else:
                        with mfa_open(likelihood_path, "r") as f:
                            likelihood = float(
                    self.return_queue.put((afst_path, likelihood))
                except Exception:
                    self.stopped.set()
                    e = KaldiProcessingError([self.log_file])
                    e.job_name = self.job_name
                    self.return_queue.put(e)
        self.finished.set()
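# Illustrative sketch (added for orientation, not part of the original module) of
# the queue protocol RandomStartWorker expects: enqueue RandomStart jobs, start
# workers, then drain the return queue, which carries ints (per-iteration
# progress ticks), (afst_path, likelihood) pairs, or exceptions. The log file
# names are hypothetical.
def _example_random_start_protocol(starts: List[RandomStart], num_workers: int = 2) -> dict:
    job_queue: Queue = Queue()
    return_queue: Queue = Queue()
    stopped = threading.Event()
    for start in starts:
        job_queue.put(start)
    workers = [
        RandomStartWorker(i, job_queue, return_queue, f"baumwelch.{i}.log", stopped)
        for i in range(num_workers)
    ]
    for w in workers:
        w.start()
    likelihoods = {}
    while any(not w.finished.is_set() for w in workers) or not return_queue.empty():
        try:
            result = return_queue.get(timeout=1)
        except queue.Empty:
            continue
        if isinstance(result, Exception):
            stopped.set()
            raise result
        if isinstance(result, tuple):
            likelihoods[result[0]] = result[1]
    for w in workers:
        w.join()
    return likelihoods
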
class G2PTrainer(MfaWorker, TrainerMixin):
    """
    Abstract mixin class for G2P training

    Parameters
    ----------
    validation_proportion: float
        Proportion of words to use as the validation set, defaults to 0.1,
        only used if ``evaluation_mode`` is True
    num_pronunciations: int
        Number of pronunciations to generate
    evaluation_mode: bool
        Flag for whether to evaluate the model performance on a validation set

    See Also
    --------
    :class:``
        For base MFA parameters
    :class:``
        For base trainer parameters

    Attributes
    ----------
    g2p_training_dictionary: dict[str, list[str]]
        Dictionary of words to pronunciations to train from
    g2p_validation_dictionary: dict[str, list[str]]
        Dictionary of words to pronunciations to validate performance against
    g2p_training_graphemes: set[str]
        Set of graphemes in the training set
    """

    def __init__(
        self,
        validation_proportion: float = 0.1,
        num_pronunciations: int = 0,
        evaluation_mode: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.evaluation_mode = evaluation_mode
        self.validation_proportion = validation_proportion
        self.num_pronunciations = num_pronunciations
        self.g2p_training_dictionary = {}
        self.g2p_validation_dictionary = None
        self.g2p_training_graphemes = set()
        self.g2p_validation_graphemes = set()
        self.g2p_training_phones = set()
        self.g2p_validation_phones = set()
class PyniniTrainerMixin:
    """
    Mixin for training Pynini G2P models

    Parameters
    ----------
    order: int
        Order of the ngram model, defaults to 8
    random_starts: int
        Number of random starts to use in initialization, defaults to 25
    delta: float
        Comparison/quantization delta for Baum-Welch training, defaults to 1/1024
    alpha: float
        Step size reduction power parameter for Baum-Welch training;
        full standard batch EM is run (not stepwise) if set to 0, defaults to 1.0
    batch_size: int
        Batch size for Baum-Welch training, defaults to 800
    num_iterations: int
        Maximum number of iterations to use in Baum-Welch training, defaults to 10
    smoothing_method: str
        Smoothing method for the ngram model, defaults to "kneser_ney"
    pruning_method: str
        Pruning method for pruning the ngram model, defaults to "relative_entropy"
    model_size: int
        Target number of ngrams for pruning, defaults to 1000000
    prune_threshold: float
        Pruning threshold used when ``model_size`` is 0, defaults to 1e-7
    insertions: bool
        Flag for whether to allow for insertions, default True
    deletions: bool
        Flag for whether to allow for deletions, default True
    fst_default_cache_gc: str
        String to pass to OpenFst binaries for GC behavior
    fst_default_cache_gc_limit: str
        String to pass to OpenFst binaries for GC behavior
    """

    def __init__(
        self,
        order: int = 8,
        random_starts: int = 25,
        delta: float = 1 / 1024,
        alpha: float = 1.0,
        batch_size: int = 800,
        num_iterations: int = 10,
        smoothing_method: str = "kneser_ney",
        pruning_method: str = "relative_entropy",
        model_size: int = 1000000,
        prune_threshold: float = 0.0000001,
        insertions: bool = True,
        deletions: bool = True,
        fst_default_cache_gc="",
        fst_default_cache_gc_limit="",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not hasattr(self, "_data_source"):
            self._data_source = None
        self.order = order
        self.random_starts = random_starts = delta
        self.alpha = alpha
        self.batch_size = batch_size
        self.num_iterations = num_iterations
        self.smoothing_method = smoothing_method
        self.pruning_method = pruning_method
        self.model_size = model_size
        self.prune_threshold = prune_threshold
        self.insertions = insertions
        self.deletions = deletions
        self.fst_default_cache_gc = fst_default_cache_gc
        self.fst_default_cache_gc_limit = fst_default_cache_gc_limit
        self._sym_path = None
        self._fst_path = None
        self.input_sym_path = None
        self.input_token_type = "utf8"
        self.output_token_type = "utf8"

    @property
    def data_source_identifier(self) -> str:
        """Dictionary name"""
        return self._data_source
    def train_iteration(self) -> None:
        """Train iteration, not used"""
        pass
    @property
    def architecture(self) -> str:
        """Pynini"""
        return "pynini"

    @property
    def input_far_path(self) -> Path:
        """Path to store grapheme archive"""
        return self.working_directory.joinpath(f"{self.data_source_identifier}.g.far")

    @property
    def output_far_path(self) -> Path:
        """Path to store phone archive"""
        return self.working_directory.joinpath(f"{self.data_source_identifier}.p.far")

    @property
    def cg_path(self) -> Path:
        """Path to covering grammar FST"""
        return self.working_directory.joinpath(f"{self.data_source_identifier}.cg.fst")

    @property
    def align_path(self) -> Path:
        """Path to store alignment models"""
        return self.working_directory.joinpath(f"{self.data_source_identifier}.align.fst")

    @property
    def afst_path(self) -> Path:
        """Path to store aligned FSTs"""
        return self.working_directory.joinpath(f"{self.data_source_identifier}.afst.far")

    @property
    def input_path(self) -> Path:
        """Path to temporary file to store grapheme training data"""
        return self.working_directory.joinpath(f"input_{self.data_source_identifier}.txt")

    @property
    def output_path(self) -> Path:
        """Path to temporary file to store phone training data"""
        return self.working_directory.joinpath(f"output_{self.data_source_identifier}.txt")
    def generate_model(self) -> None:
        """
        Generate an ngram G2P model from FAR strings
        """
        assert os.path.exists(self.far_path)
        if os.path.exists(self.fst_path):
  "Model building already done, skipping!")
            return
        with mfa_open(self.working_log_directory.joinpath("model.log"), "w") as logf:
            ngramcount_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramcount"),
                    "--require_symbols=false",
                    f"--order={self.order}",
                    self.far_path,
                ],
                stderr=logf,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            ngrammake_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngrammake"),
                    f"--method={self.smoothing_method}",
                ],
                stdin=ngramcount_proc.stdout,
                stderr=logf,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            command = [
                thirdparty_binary("ngramshrink"),
                f"--method={self.pruning_method}",
            ]
            if self.model_size > 0:
                command.append(f"--target_number_of_ngrams={self.model_size}")
            else:
                command.append(f"--theta={self.prune_threshold}")
            ngramshrink_proc = subprocess.Popen(
                command
                + [
                    "-",
                    self.far_path.with_suffix(".shrink"),
                ],
                stdin=ngrammake_proc.stdout,
                stdout=subprocess.PIPE,
                stderr=logf,
                env=os.environ,
            )
            ngramshrink_proc.communicate()
            assert self.far_path.with_suffix(".shrink").exists()

            fstencode_proc = subprocess.Popen(
                [
                    thirdparty_binary("fstencode"),
                    "--decode",
                    self.far_path.with_suffix(".shrink"),
                    self.encoder_path,
                    self.far_path.with_suffix(".dec"),
                ],
                stderr=logf,
                env=os.environ,
            )
            fstencode_proc.communicate()
            assert self.far_path.with_suffix(".dec").exists()
            self.far_path.with_suffix(".shrink").unlink()

            sort_proc = subprocess.Popen(
                [
                    thirdparty_binary("fstarcsort"),
                    self.far_path.with_suffix(".dec"),
                    self.fst_path,
                ],
                stderr=logf,
                env=os.environ,
            )
            sort_proc.communicate()
            assert self.fst_path.exists()
            self.far_path.with_suffix(".dec").unlink()
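    # For orientation (added comment, not in the original source): with NAME
    # standing in for data_source_identifier and the default settings, the
    # subprocess chain in generate_model is roughly this shell pipeline:
    #
    #   ngramcount --require_symbols=false --order=8 NAME.far \
    #     | ngrammake --method=kneser_ney \
    #     | ngramshrink --method=relative_entropy \
    #         --target_number_of_ngrams=1000000 - NAME.shrink
    #   fstencode --decode NAME.shrink NAME.enc NAME.dec
    #   fstarcsort NAME.dec NAME.fst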
    @property
    def fst_path(self) -> Path:
        """Internal temporary FST file"""
        if self._fst_path is not None:
            return self._fst_path
        return self.working_directory.joinpath(f"{self.data_source_identifier}.fst")

    @property
    def far_path(self) -> Path:
        """Internal temporary FAR file"""
        return self.working_directory.joinpath(f"{self.data_source_identifier}.far")

    @property
    def encoder_path(self) -> Path:
        """Internal temporary encoder file"""
        return self.working_directory.joinpath(f"{self.data_source_identifier}.enc")

    @property
    def sym_path(self) -> Path:
        """Internal temporary symbol file"""
        if self._sym_path is not None:
            return self._sym_path
        return self.working_directory.joinpath("phones.txt")
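    # Taken together, the path properties above imply this working-directory
    # layout for a data source named NAME (summary added for orientation, not in
    # the original source):
    #
    #   input_NAME.txt / output_NAME.txt   parallel grapheme/phone training text
    #   NAME.g.far / NAME.p.far            compiled grapheme/phone archives
    #   NAME.cg.fst                        covering grammar
    #   NAME.align.fst / NAME.afst.far     alignment model and aligned FSTs
    #   NAME.far / NAME.enc / NAME.fst     encoded alignments, encoder, final model
    #   phones.txt                         phone symbol table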
    def align_g2p(self) -> None:
        """Runs the entire alignment regimen."""
        self._lexicon_covering()
        self._alignments()
        self._encode()
    @staticmethod
    def _narcs(f: Fst) -> int:
        """Computes the number of arcs in an FST."""
        return sum(f.num_arcs(state) for state in f.states())

    def _lexicon_covering(self, input_path=None, output_path=None) -> None:
        """Builds covering grammar and lexicon FARs."""
        # Sets of labels for the covering grammar.
        with mfa_open(
            self.working_log_directory.joinpath("covering_grammar.log"), "w"
        ) as log_file:
            if input_path is None:
                input_path = self.input_path
            if output_path is None:
                output_path = self.output_path
            com = [
                thirdparty_binary("farcompilestrings"),
                "--fst_type=compact",
            ]
            if self.input_token_type != "utf8":
                com.append("--token_type=symbol")
                com.append(f"--symbols={self.input_token_type}")
                com.append("--unknown_symbol=<unk>")
            else:
                com.append("--token_type=utf8")
            com.extend([input_path, self.input_far_path])
            print(" ".join(map(str, com)), file=log_file)
            subprocess.check_call(com, env=os.environ, stderr=log_file, stdout=log_file)
            com = [
                thirdparty_binary("farcompilestrings"),
                "--fst_type=compact",
                "--token_type=symbol",
                f"--symbols={self.phone_symbol_table_path}",
                output_path,
                self.output_far_path,
            ]
            print(" ".join(map(str, com)), file=log_file)
            subprocess.check_call(com, env=os.environ, stderr=log_file, stdout=log_file)
            ilabels = _get_far_labels(self.input_far_path)
            print(ilabels, file=log_file)
            olabels = _get_far_labels(self.output_far_path)
            print(olabels, file=log_file)
            cg = pywrapfst.VectorFst()
            state = cg.add_state()
            cg.set_start(state)
            one = pywrapfst.Weight.one(cg.weight_type())
            for ilabel, olabel in itertools.product(ilabels, olabels):
                cg.add_arc(state, pywrapfst.Arc(ilabel, olabel, one, state))
            # Handles epsilons, carefully avoiding adding a useless 0:0 label.
            if self.insertions:
                for olabel in olabels:
                    cg.add_arc(state, pywrapfst.Arc(0, olabel, one, state))
            if self.deletions:
                for ilabel in ilabels:
                    cg.add_arc(state, pywrapfst.Arc(ilabel, 0, one, state))
            cg.set_final(state)
            assert cg.verify(), "Label acceptor is ill-formed"
            cg.write(self.cg_path)

    def _alignments(self) -> None:
        """Trains the aligner and constructs the alignments FAR."""
        if not os.path.exists(self.align_path):
  "Training aligner")
            train_opts = []
            if self.batch_size:
                train_opts.append(f"--batch_size={self.batch_size}")
            if
                train_opts.append(f"--delta={}")
            if self.fst_default_cache_gc:
                train_opts.append(f"--fst_default_cache_gc={self.fst_default_cache_gc}")
            if self.fst_default_cache_gc_limit:
                train_opts.append(
                    f"--fst_default_cache_gc_limit={self.fst_default_cache_gc_limit}"
                )
            if self.alpha:
                train_opts.append(f"--alpha={self.alpha}")
            if self.num_iterations:
                train_opts.append(f"--max_iters={self.num_iterations}")
            # Constructs the actual command vectors (plus an index for logging
            # purposes).
            random.seed(config.SEED)
            starts = [
                RandomStart(
                    idx,
                    seed,
                    self.input_far_path,
                    self.output_far_path,
                    self.cg_path,
                    self.working_directory,
                    train_opts,
                )
                for (idx, seed) in enumerate(
                    random.sample(range(1, RAND_MAX), self.random_starts), 1
                )
            ]
            stopped = threading.Event()
            num_commands = len(starts)
            job_queue = Queue()
            fst_likelihoods = {}
            # Actually runs starts.
  "Calculating alignments...")
            begin = time.time()
            with tqdm(total=num_commands * self.num_iterations, disable=config.QUIET) as pbar:
                for start in starts:
                    job_queue.put(start)
                error_dict = {}
                return_queue = Queue()
                procs = []
                for i in range(config.NUM_JOBS):
                    log_path = self.working_log_directory.joinpath(f"baumwelch.{i}.log")
                    p = RandomStartWorker(
                        i,
                        job_queue,
                        return_queue,
                        log_path,
                        stopped,
                    )
                    procs.append(p)
                    p.start()

                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.is_set():
                            continue
                    except queue.Empty:
                        for proc in procs:
                            if not proc.finished.is_set():
                                break
                        else:
                            break
                        continue
                    if isinstance(result, int):
                        pbar.update(result)
                    else:
                        fst_likelihoods[result[0]] = result[1]
                    return_queue.task_done()
            for p in procs:
                p.join()
            if error_dict:
                raise PyniniAlignmentError(error_dict)
            (best_fst, best_likelihood) = min(
                fst_likelihoods.items(), key=operator.itemgetter(1)
            )
  "Best likelihood: {best_likelihood}")
            logger.debug(
                f"Ran {self.random_starts} random starts in {time.time() - begin:.3f} seconds"
            )
            # Moves the best likelihood solution to the requested location.
            shutil.move(best_fst, self.align_path)
        cmd = [thirdparty_binary("baumwelchdecode")]
        if self.fst_default_cache_gc:
            cmd.append(f"--fst_default_cache_gc={self.fst_default_cache_gc}")
        if self.fst_default_cache_gc_limit:
            cmd.append(f"--fst_default_cache_gc_limit={self.fst_default_cache_gc_limit}")
        cmd.append(self.input_far_path)
        cmd.append(self.output_far_path)
        cmd.append(self.align_path)
        cmd.append(self.afst_path)
        cmd = [str(x) for x in cmd]
        logger.debug(f"Subprocess call: {cmd}")
        subprocess.check_call(cmd, env=os.environ)
"Completed computing alignments!")

    def _encode(self) -> None:
        """Encodes the alignments."""
"Encoding the alignments as FSAs")
        subprocess.check_call(
            [
                thirdparty_binary("farencode"),
                "--encode_labels",
                self.afst_path,
                self.encoder_path,
                self.far_path,
            ],
            env=os.environ,
        )
        temp_far_path = self.far_path.with_stem("far_temp")
        far_reader = pynini.Far(self.far_path, mode="r")
        far_writer = pynini.Far(temp_far_path, mode="w")
        for key, fst in far_reader:
            fst = pynini.arcmap(fst, map_type="rmweight")
            far_writer.add(key, fst)
        far_writer.close()
        far_reader.close()
        del far_reader
        del far_writer
        del fst
        self.far_path.unlink()
        temp_far_path.rename(self.far_path)
"Success! FAR path: {self.far_path}; encoder path: {self.encoder_path}")
class PyniniTrainer(
    MultispeakerDictionaryMixin, PyniniTrainerMixin, G2PTrainer, TopLevelMfaWorker
):
    """
    Top-level G2P trainer that uses Pynini functionality

    See Also
    --------
    :class:`~montreal_forced_aligner.g2p.trainer.G2PTrainer`
        For base G2P training parameters
    :class:``
        For top-level parameters
    """

    def __init__(
        self,
        **kwargs,
    ):
        self._data_source = os.path.splitext(os.path.basename(kwargs["dictionary_path"]))[0]
        super().__init__(**kwargs)
        self._fst_path = None
        self._sym_path = None
        self.position_dependent_phones = False
        self.wer = None
        self.ler = None

    @property
    def data_directory(self) -> Path:
        """Data directory for trainer"""
        return self.working_directory

    @property
    def configuration(self) -> MetaDict:
        """Configuration for G2P trainer"""
        config = super().configuration
        config.update({"dictionary_path": str(self.dictionary_model.path)})
        return config
    def setup(self) -> None:
        """Setup for G2P training"""
        super().setup()
        self.create_new_current_workflow(WorkflowType.train_g2p)
        wf = self.current_workflow
        if wf.done:
  "G2P training already done, skipping.")
            return
        self.dictionary_setup()
        os.makedirs(self.phones_dir, exist_ok=True)
        self.phone_table.write_text(str(self.phone_symbol_table_path))
        self._write_grapheme_symbol_table()
        os.makedirs(self.working_log_directory, exist_ok=True)
        self.initialize_training()
        self.initialized = True
    @property
    def meta(self) -> MetaDict:
        """Metadata for exported G2P model"""
        from datetime import datetime

        from ..utils import get_mfa_version

        meta = {
            "version": get_mfa_version(),
            "architecture": self.architecture,
            "train_date": str(,
            "phones": sorted(self.non_silence_phones),
            "graphemes": self.g2p_training_graphemes,
            "evaluation": {},
            "training": {
                "num_words": len(self.g2p_training_dictionary),
                "num_graphemes": len(self.g2p_training_graphemes),
                "num_phones": len(self.non_silence_phones),
            },
        }
        if self.model_version is not None:
            meta["version"] = self.model_version
        if self.evaluation_mode:
            meta["evaluation"]["num_words"] = len(self.g2p_validation_dictionary)
            meta["evaluation"]["word_error_rate"] = self.wer
            meta["evaluation"]["phone_error_rate"] = self.ler
        return meta
    def initialize_training(self) -> None:
        """Initialize training G2P model"""
        random.seed(config.SEED)
        self._sym_path = self.phone_symbol_table_path
        self.output_token_type = pywrapfst.SymbolTable.read_text(self.phone_symbol_table_path)
        with self.session() as session:
            self.g2p_training_dictionary = {}
            pronunciations = (
                session.query(Word.word, Pronunciation.pronunciation)
                .join(Pronunciation.word)
                .filter(Word.word_type.in_(WordType.speech_types()))
            )
            for w, p in pronunciations:
                if w not in self.g2p_training_dictionary:
                    self.g2p_training_dictionary[w] = set()
                self.g2p_training_dictionary[w].add(p)
            if self.evaluation_mode:
                word_dict = self.g2p_training_dictionary
                words = sorted(word_dict.keys())
                total_items = len(words)
                validation_items = int(total_items * self.validation_proportion)
                validation_words = random.sample(words, validation_items)
                self.g2p_training_dictionary = {
                    k: v for k, v in word_dict.items() if k not in validation_words
                }
                self.g2p_validation_dictionary = {
                    k: v for k, v in word_dict.items() if k in validation_words
                }
                if config.DEBUG:
                    with mfa_open(
                        self.working_directory.joinpath("validation_set.txt"),
                        "w",
                        encoding="utf8",
                    ) as f:
                        for word in self.g2p_validation_dictionary:
                            f.write(word + "\n")
            with mfa_open(self.input_path, "w") as inf, mfa_open(self.output_path, "w") as outf:
                for word, pronunciations in self.g2p_training_dictionary.items():
                    if re.match(r"\W", word) is not None:
                        continue
                    self.g2p_training_graphemes.update(word)
                    for p in pronunciations:
                        self.g2p_training_phones.update(p.split())
                        print(word, file=inf)
                        print(p, file=outf)
            logger.debug(f"Graphemes in training data: {sorted(self.g2p_training_graphemes)}")
            logger.debug(f"Phones in training data: {sorted(self.g2p_training_phones)}")
            if self.evaluation_mode:
                for word, pronunciations in self.g2p_validation_dictionary.items():
                    self.g2p_validation_graphemes.update(word)
                    for p in pronunciations:
                        self.g2p_validation_phones.update(p.split())
                logger.debug(
                    f"Graphemes in validation data: {sorted(self.g2p_validation_graphemes)}"
                )
                logger.debug(f"Phones in validation data: {sorted(self.g2p_validation_phones)}")
                grapheme_diff = sorted(
                    self.g2p_validation_graphemes - self.g2p_training_graphemes
                )
                phone_diff = sorted(self.g2p_validation_phones - self.g2p_training_phones)
                if grapheme_diff:
                    logger.debug(
                        f"The following graphemes appear only in the validation set: "
                        f"{', '.join(grapheme_diff)}"
                    )
                if phone_diff:
                    logger.debug(
                        f"The following phones appear only in the validation set: "
                        f"{', '.join(phone_diff)}"
                    )
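    # Illustrative note (added for orientation, not in the original source):
    # initialize_training writes parallel text files in which line i of
    # input_NAME.txt pairs with line i of output_NAME.txt, one line per
    # (word, pronunciation) pair. With hypothetical ARPA-style data:
    #
    #   input_NAME.txt    output_NAME.txt
    #   cat               K AE1 T
    #   cats              K AE1 T S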
    def clean_up(self) -> None:
        """
        Clean up temporary files
        """
        if config.DEBUG:
            return
        for name in os.listdir(self.working_directory):
            path = self.working_directory.joinpath(name)
            if os.path.isdir(path):
                shutil.rmtree(path, ignore_errors=True)
            elif not name.endswith(".log"):
                os.remove(path)
    def export_model(self, output_model_path: Path) -> None:
        """
        Export G2P model to specified path

        Parameters
        ----------
        output_model_path: :class:`~pathlib.Path`
            Path to export model
        """
        directory = output_model_path.parent
        directory.mkdir(parents=True, exist_ok=True)
        models_temp_dir = self.working_directory.joinpath("model_archive_temp")
        model = G2PModel.empty(output_model_path.stem, root_directory=models_temp_dir)
        model.add_meta_file(self)
        model.add_fst_model(self.working_directory)
        model.add_sym_path(self.working_directory)
        if directory:
            os.makedirs(directory, exist_ok=True)
        model.dump(output_model_path)
        model.clean_up()
        # self.clean_up()
"Saved model to {output_model_path}")
    def train(self) -> None:
        """
        Train a G2P model
        """
        os.makedirs(self.working_log_directory, exist_ok=True)
        begin = time.time()
        if os.path.exists(self.far_path) and os.path.exists(self.encoder_path):
  "Alignment already done, skipping!")
        else:
            self.align_g2p()
        logger.debug(
            f"Aligning {len(self.g2p_training_dictionary)} words took "
            f"{time.time() - begin:.3f} seconds"
        )
        begin = time.time()
        self.generate_model()
        logger.debug(
            f"Generating model for {len(self.g2p_training_dictionary)} words took "
            f"{time.time() - begin:.3f} seconds"
        )
        self.finalize_training()
    def finalize_training(self) -> None:
        """Finalize training"""
        shutil.copyfile(self.fst_path, self.working_directory.joinpath("model.fst"))
        shutil.copyfile(
            self.phone_symbol_table_path, self.working_directory.joinpath("phones.txt")
        )
        if self.evaluation_mode:
            self.evaluate_g2p_model()
    def evaluate_g2p_model(self) -> None:
        """
        Validate the G2P model against held-out data
        """
        temp_model_path = self.working_log_directory.joinpath("")
        self.export_model(temp_model_path)
        gen = PyniniValidator(
            g2p_model_path=temp_model_path,
            word_list=list(self.g2p_validation_dictionary.keys()),
            num_jobs=config.NUM_JOBS,
            num_pronunciations=self.num_pronunciations,
        )
        gen.evaluate_g2p_model(self.g2p_validation_dictionary)
        self.wer = gen.wer
        self.ler = gen.ler
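
# A minimal end-to-end usage sketch (added for orientation, not part of the
# original module). In practice this class is normally driven by the
# `mfa train_g2p` command; the dictionary and output paths below are hypothetical.
def _example_train_g2p() -> None:
    trainer = PyniniTrainer(dictionary_path=Path("english.dict"), evaluation_mode=True)
    trainer.setup()          # builds the dictionary database and training files
    trainer.train()          # alignment, model generation, and (here) evaluation
    trainer.export_model(Path(""))
    trainer.clean_up()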