Source code for montreal_forced_aligner.acoustic_modeling.trainer

"""Class definitions for trainable aligners"""
from __future__ import annotations

import collections
import json
import logging
import multiprocessing as mp
import os
import re
import shutil
import subprocess
import time
import typing
from pathlib import Path
from queue import Empty
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from sqlalchemy.orm import Session, joinedload, subqueryload
from tqdm.rich import tqdm

from montreal_forced_aligner.abc import KaldiFunction, ModelExporterMixin, TopLevelMfaWorker
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import MfaArguments, WorkflowType
from montreal_forced_aligner.db import (
    CorpusWorkflow,
    Dictionary,
    Job,
    Speaker,
    Utterance,
    bulk_update,
)
from montreal_forced_aligner.exceptions import ConfigError, KaldiProcessingError
from montreal_forced_aligner.helper import load_configuration, mfa_open, parse_old_features
from montreal_forced_aligner.models import AcousticModel, DictionaryModel
from montreal_forced_aligner.transcription.transcriber import TranscriberMixin
from montreal_forced_aligner.utils import (
    KaldiProcessWorker,
    Stopped,
    log_kaldi_errors,
    thirdparty_binary,
)

if TYPE_CHECKING:
    from dataclasses import dataclass

    from montreal_forced_aligner.abc import MetaDict
    from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin
    from montreal_forced_aligner.acoustic_modeling.pronunciation_probabilities import (
        PronunciationProbabilityTrainer,
    )
else:
    from dataclassy import dataclass

__all__ = ["TrainableAligner", "TransitionAccFunction", "TransitionAccArguments"]


logger = logging.getLogger("mfa")


[docs] @dataclass class TransitionAccArguments(MfaArguments): """Arguments for :class:`~montreal_forced_aligner.acoustic_modeling.trainer.TransitionAccFunction`""" model_path: Path
[docs] class TransitionAccFunction(KaldiFunction): """ Multiprocessing function to accumulate transition stats See Also -------- :kaldi_src:`ali-to-post` Relevant Kaldi binary :kaldi_src:`post-to-tacc` Relevant Kaldi binary Parameters ---------- args: :class:`~montreal_forced_aligner.acoustic_modeling.trainer.TransitionAccArguments` Arguments for the function """ done_pattern = re.compile( r"^LOG \(post-to-tacc.*Done computing transition stats over (?P<utterances>\d+) utterances.*$" ) def __init__(self, args: TransitionAccArguments): super().__init__(args) self.model_path = args.model_path def _run(self) -> typing.Generator[typing.Tuple[int, str]]: """Run the function""" with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session: job = ( session.query(Job) .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries)) .filter(Job.id == self.job_name) .first() ) workflow: CorpusWorkflow = ( session.query(CorpusWorkflow) .filter(CorpusWorkflow.current == True) # noqa .first() ) for dict_id in job.dictionary_ids: ali_path = job.construct_path(workflow.working_directory, "ali", "ark", dict_id) tacc_path = job.construct_path(workflow.working_directory, "t", "acc", dict_id) ali_post_proc = subprocess.Popen( [ thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-", ], stdout=subprocess.PIPE, env=os.environ, stderr=log_file, ) tacc_proc = subprocess.Popen( [ thirdparty_binary("post-to-tacc"), self.model_path, "ark:-", tacc_path, ], stdin=ali_post_proc.stdout, env=os.environ, stderr=subprocess.PIPE, encoding="utf8", ) for line in tacc_proc.stderr: log_file.write(line) m = self.done_pattern.match(line.strip()) if m: progress_update = int(m.group("utterances")) yield progress_update self.check_call(tacc_proc)
[docs] class TrainableAligner(TranscriberMixin, TopLevelMfaWorker, ModelExporterMixin): """ Train acoustic model Parameters ---------- training_configuration : list[tuple[str, dict[str, Any]]] Training identifiers and parameters for training blocks phone_set_type: str Type of phone set to use for acoustic modeling See Also -------- :class:`~montreal_forced_aligner.alignment.base.CorpusAligner` For dictionary and corpus parsing parameters and alignment parameters :class:`~montreal_forced_aligner.abc.TopLevelMfaWorker` For top-level parameters :class:`~montreal_forced_aligner.abc.ModelExporterMixin` For model export parameters Attributes ---------- param_dict: dict[str, Any] Parameters to pass to training blocks final_identifier: str Identifier of the final training block current_subset: int Current training block's subset current_acoustic_model: :class:`~montreal_forced_aligner.models.AcousticModel` Acoustic model to use in aligning, based on previous training block training_configs: dict[str, :class:`~montreal_forced_aligner.acoustic_modeling.base.AcousticModelTrainingMixin`] Training blocks """ def __init__( self, training_configuration: List[Tuple[str, Dict[str, Any]]] = None, phone_set_type: str = None, **kwargs, ): self.param_dict = { k: v for k, v in kwargs.items() if not k.endswith("_directory") and not k.endswith("_path") and k not in ["speaker_characters"] } self.final_identifier = None self.current_subset: int = 0 self.current_aligner: Optional[AcousticModelTrainingMixin] = None self.current_trainer: Optional[AcousticModelTrainingMixin] = None self.current_acoustic_model: Optional[AcousticModel] = None super().__init__(**kwargs) if phone_set_type and phone_set_type != "UNKNOWN": self.dictionary_model = DictionaryModel( self.dictionary_model.path, phone_set_type=phone_set_type ) self.phone_set_type = self.dictionary_model.phone_set_type os.makedirs(self.output_directory, exist_ok=True) self.training_configs: Dict[ str, typing.Union[AcousticModelTrainingMixin, PronunciationProbabilityTrainer] ] = {} if training_configuration is None: training_configuration = TrainableAligner.default_training_configurations() for k, v in training_configuration: self.add_config(k, v)
[docs] @classmethod def default_training_configurations(cls) -> List[Tuple[str, Dict[str, Any]]]: """Default MFA training configuration""" training_params = [] training_params.append(("monophone", {"subset": 10000, "boost_silence": 1.25})) training_params.append( ( "triphone", { "subset": 20000, "boost_silence": 1.25, "num_leaves": 2500, "max_gaussians": 10000, }, ) ) training_params.append( ("lda", {"subset": 20000, "num_leaves": 3000, "max_gaussians": 15000}) ) training_params.append( ("sat", {"subset": 20000, "num_leaves": 4000, "max_gaussians": 15000}) ) training_params.append( ("sat", {"subset": 50000, "num_leaves": 5000, "max_gaussians": 40000}) ) training_params.append(("pronunciation_probabilities", {"subset": 50000})) training_params.append( ("sat", {"subset": 150000, "num_leaves": 6000, "max_gaussians": 100000}) ) training_params.append( ( "pronunciation_probabilities", {"subset": 150000, "optional": True}, ) ) training_params.append( ( "sat", { "subset": 0, "num_leaves": 7000, "optional": True, "max_gaussians": 150000, "num_iterations": 20, "quick": True, }, ) ) return training_params
[docs] @classmethod def parse_parameters( cls, config_path: Optional[Path] = None, args: Optional[Dict[str, Any]] = None, unknown_args: Optional[typing.Iterable[str]] = None, ) -> MetaDict: """ Parse configuration parameters from a config file and command line arguments Parameters ---------- config_path: :class:`~pathlib.Path`, optional Path to yaml configuration file args: dict[str, Any] Parsed arguments unknown_args: list[str] Optional list of arguments that were not parsed Returns ------- dict[str, Any] Dictionary of specified configuration parameters """ global_params = {} training_params = [] use_default = True if config_path: data = load_configuration(config_path) training_params = [] for k, v in data.items(): if k == "training": for t in v: for k2, v2 in t.items(): if "features" in v2: global_params.update(parse_old_features(v2["features"])) del v2["features"] training_params.append((k2, v2)) elif k == "features": global_params.update(parse_old_features(v)) else: if v is None and k in cls.nullable_fields: v = [] global_params[k] = v if training_params: use_default = False if use_default: # default training configuration training_params = TrainableAligner.default_training_configurations() if training_params: if training_params[0][0] != "monophone": raise ConfigError("The first round of training must be monophone.") global_params["training_configuration"] = training_params global_params.update(cls.parse_args(args, unknown_args)) return global_params
def setup_trainers(self): self.write_training_information() with self.session() as session: workflows: typing.Dict[str, CorpusWorkflow] = { x.name: x for x in session.query(CorpusWorkflow) } for i, (identifier, config) in enumerate(self.training_configs.items()): if isinstance(config, str): continue config.non_silence_phones = self.non_silence_phones ali_identifier = f"{identifier}_ali" if identifier not in workflows: self.create_new_current_workflow( WorkflowType.acoustic_training, name=identifier ) self.create_new_current_workflow(WorkflowType.alignment, name=ali_identifier) else: wf = workflows[identifier] if wf.dirty and not wf.done: shutil.rmtree(wf.working_directory, ignore_errors=True) ali_wf = workflows[ali_identifier] if ali_wf.dirty and not ali_wf.done: shutil.rmtree(ali_wf.working_directory, ignore_errors=True) if i == 0: wf.current = True session.commit() def filter_training_utterances(self): logger.info("Filtering utterances for training...") with self.session() as session: dictionaries = session.query(Dictionary) for d in dictionaries: update_mapping = [] word_mapping = d.word_mapping utterances = ( session.query(Utterance.id, Utterance.normalized_text) .join(Utterance.speaker) .filter(Utterance.ignored == False) # noqa .filter(Speaker.dictionary_id == d.id) ) for u_id, text in utterances: if not text: update_mapping.append({"id": u_id, "ignored": True}) continue words = text.split() if any(x in word_mapping for x in words): continue update_mapping.append({"id": u_id, "ignored": True}) if update_mapping: bulk_update(session, Utterance, update_mapping) session.commit()
[docs] def setup(self) -> None: """Setup for acoustic model training""" super().setup() if self.initialized: return try: self.load_corpus() self.setup_trainers() self.filter_training_utterances() except Exception as e: if isinstance(e, KaldiProcessingError): log_kaldi_errors(e.error_logs) e.update_log_file() raise self.initialized = True
@property def configuration(self) -> MetaDict: """Configuration for the worker""" config = super().configuration config.update( { "dictionary_path": str(self.dictionary_model.path), "corpus_directory": self.corpus_directory, } ) return config @property def meta(self) -> MetaDict: """Metadata about the final round of training""" return self.training_configs[self.final_identifier].meta
[docs] def add_config(self, train_type: str, params: MetaDict) -> None: """ Add a trainer to the pipeline Parameters ---------- train_type: str Type of trainer to add, one of ``monophone``, ``triphone``, ``lda`` or ``sat`` params: dict[str, Any] Parameters to initialize trainer Raises ------ :class:`~montreal_forced_aligner.exceptions.ConfigError` If an invalid train_type is specified """ from montreal_forced_aligner.acoustic_modeling.lda import LdaTrainer from montreal_forced_aligner.acoustic_modeling.monophone import MonophoneTrainer from montreal_forced_aligner.acoustic_modeling.pronunciation_probabilities import ( # noqa PronunciationProbabilityTrainer, ) from montreal_forced_aligner.acoustic_modeling.sat import SatTrainer from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer p = {} p.update(self.param_dict) p.update(params) identifier = train_type index = 2 while identifier in self.training_configs: identifier = f"{train_type}_{index}" index += 1 self.final_identifier = identifier if train_type == "monophone": p = { k: v for k, v in p.items() if k in MonophoneTrainer.get_configuration_parameters() } config = MonophoneTrainer(identifier=identifier, worker=self, **p) elif train_type == "triphone": p = {k: v for k, v in p.items() if k in TriphoneTrainer.get_configuration_parameters()} config = TriphoneTrainer(identifier=identifier, worker=self, **p) elif train_type == "lda": p = {k: v for k, v in p.items() if k in LdaTrainer.get_configuration_parameters()} config = LdaTrainer(identifier=identifier, worker=self, **p) elif train_type == "sat": p = {k: v for k, v in p.items() if k in SatTrainer.get_configuration_parameters()} config = SatTrainer(identifier=identifier, worker=self, **p) elif train_type == "pronunciation_probabilities": p = { k: v for k, v in p.items() if k in PronunciationProbabilityTrainer.get_configuration_parameters() } previous_trainer = self.training_configs[list(self.training_configs.keys())[-1]] config = PronunciationProbabilityTrainer( identifier=identifier, previous_trainer=previous_trainer, worker=self, **p ) else: raise ConfigError(f"Invalid training type '{train_type}' in config file") self.training_configs[identifier] = config
[docs] def export_model(self, output_model_path: Path) -> None: """ Export an acoustic model to the specified path Parameters ---------- output_model_path : str Path to save acoustic model """ if "pronunciation_probabilities" in self.training_configs: export_directory = os.path.dirname(output_model_path) if export_directory: os.makedirs(export_directory, exist_ok=True) # self.export_trained_rules( # self.training_configs[self.final_identifier].working_directory # ) with self.session() as session: for d in session.query(Dictionary): base_name = self.dictionary_base_names[d.id] if d.use_g2p: shutil.copyfile( self.phone_symbol_table_path, os.path.join( self.training_configs[self.final_identifier].working_directory, "phones.txt", ), ) shutil.copyfile( self.grapheme_symbol_table_path, os.path.join( self.training_configs[self.final_identifier].working_directory, "graphemes.txt", ), ) shutil.copyfile( d.lexicon_fst_path, os.path.join( self.training_configs[self.final_identifier].working_directory, self.dictionary_base_names[d.id] + ".fst", ), ) shutil.copyfile( d.align_lexicon_path, os.path.join( self.training_configs[self.final_identifier].working_directory, self.dictionary_base_names[d.id] + "_align.fst", ), ) else: output_dictionary_path = os.path.join( export_directory, base_name + ".dict" ) self.export_lexicon( d.id, output_dictionary_path, probability=True, ) self.training_configs[self.final_identifier].export_model(output_model_path) logger.info(f"Saved model to {output_model_path}")
@property def tree_path(self) -> str: """Tree path of the final model""" return self.training_configs[self.final_identifier].tree_path
[docs] def train(self) -> None: """ Run through the training configurations to produce a final acoustic model """ self.setup() previous = None begin = time.time() for trainer in self.training_configs.values(): if self.current_subset is None and trainer.optional: logger.info( "Exiting training early to save time as the corpus is below the subset size for later training stages" ) break if trainer.subset < self.num_utterances: self.current_subset = trainer.subset else: self.current_subset = None trainer.subset = 0 self.subset_directory(self.current_subset) if previous is not None: self.set_current_workflow(f"{previous.identifier}_ali") self.current_aligner = previous os.makedirs(self.working_directory, exist_ok=True) self.current_acoustic_model = AcousticModel( previous.exported_model_path, self.working_directory ) self.align() self.set_current_workflow(trainer.identifier) if trainer.identifier.startswith("pronunciation_probabilities"): trainer.train_pronunciation_probabilities() else: trainer.train() previous = trainer self.final_identifier = trainer.identifier self.current_subset = None self.current_aligner = previous self.set_current_workflow(f"{previous.identifier}_ali") os.makedirs(self.working_log_directory, exist_ok=True) self.current_acoustic_model = AcousticModel( previous.exported_model_path, self.working_directory ) self.acoustic_model = AcousticModel(previous.exported_model_path, self.working_directory) self.align() self.finalize_training() counts_path = self.working_directory.joinpath("phone_pdf.counts") new_counts_path = os.path.join(previous.working_directory, "phone_pdf.counts") if not os.path.exists(new_counts_path): shutil.copyfile(counts_path, new_counts_path) phone_lm_path = os.path.join(self.phones_dir, "phone_lm.fst") new_phone_lm_path = os.path.join(previous.working_directory, "phone_lm.fst") if not os.path.exists(new_phone_lm_path) and os.path.exists(phone_lm_path): shutil.copyfile(phone_lm_path, new_phone_lm_path) logger.info(f"Completed training in {time.time()-begin} seconds!")
[docs] def transition_acc_arguments(self) -> List[TransitionAccArguments]: """ Generate Job arguments for :class:`~montreal_forced_aligner.acoustic_modeling.trainer.TransitionAccArguments` Returns ------- list[:class:`~montreal_forced_aligner.acoustic_modeling.trainer.TransitionAccArguments`] Arguments for processing """ return [ TransitionAccArguments( j.id, getattr(self, "db_string", ""), self.working_log_directory.joinpath(f"test_utterances.{j.id}.log"), self.model_path, ) for j in self.jobs ]
[docs] def compute_phone_pdf_counts(self) -> None: """ Calculate the counts of pdfs corresponding to phones """ try: logger.info("Accumulating transition stats...") begin = time.time() log_directory = self.working_log_directory os.makedirs(log_directory, exist_ok=True) arguments = self.transition_acc_arguments() with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar: if GLOBAL_CONFIG.use_mp: error_dict = {} return_queue = mp.Queue() stopped = Stopped() procs = [] for i, args in enumerate(arguments): function = TransitionAccFunction(args) p = KaldiProcessWorker(i, return_queue, function, stopped) procs.append(p) p.start() while True: try: result = return_queue.get(timeout=1) if stopped.stop_check(): continue except Empty: for proc in procs: if not proc.finished.stop_check(): break else: break continue if isinstance(result, KaldiProcessingError): error_dict[result.job_name] = result continue pbar.update(result) for p in procs: p.join() if error_dict: for v in error_dict.values(): raise v else: logger.debug("Not using multiprocessing...") for args in arguments: function = TransitionAccFunction(args) for result in function.run(): pbar.update(result) t_accs = [] for j in self.jobs: for dict_id in j.dictionary_ids: t_accs.append(j.construct_path(self.working_directory, "t", "acc", dict_id)) subprocess.check_call( [ thirdparty_binary("vector-sum"), "--binary=false", *t_accs, self.working_directory.joinpath("final.tacc"), ], stderr=subprocess.DEVNULL, ) for f in t_accs: os.remove(f) smoothing = 1 show_proc = subprocess.Popen( [ thirdparty_binary("show-transitions"), self.phone_symbol_table_path, self.model_path, ], stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, encoding="utf8", env=os.environ, ) phone_pdfs = {} phone, pdf = None, None max_pdf = 0 max_phone = 0 for line in show_proc.stdout: line = line.strip() m = re.match( r"^Transition-state.*phone = (?P<phone>[^ ]+) .*pdf = (?P<pdf>\d+)$", line ) if m: phone = m.group("phone") pdf = int(m.group("pdf")) if pdf > max_pdf: max_pdf = pdf if self.phone_mapping[phone] > max_phone: max_phone = self.phone_mapping[phone] else: m = re.search(r"Transition-id = (?P<transition_id>\d+)", line) if m: transition_id = int(m.group("transition_id")) phone_pdfs[transition_id] = (phone, pdf) with mfa_open(self.working_directory.joinpath("final.tacc"), "r") as f: data = f.read().strip().split()[1:-1] transition_counts = { i: smoothing + int(float(x)) for i, x in enumerate(data) if i != 0 } assert len(transition_counts) == len(phone_pdfs) pdf_counts = collections.Counter() pdf_phone_counts = collections.Counter() phone_pdf_mapping = collections.defaultdict(collections.Counter) for transition_id, (phone, pdf) in phone_pdfs.items(): pdf_counts[pdf] += transition_counts[transition_id] pdf_phone_counts[(phone, pdf)] += transition_counts[transition_id] phone_pdf_mapping[phone][pdf] += transition_counts[transition_id] with mfa_open(self.working_directory.joinpath("phone_pdf.counts"), "w") as f: json.dump(phone_pdf_mapping, f, ensure_ascii=False) logger.debug(f"Accumulating transition stats took {time.time() - begin:.3f} seconds") logger.info("Finished accumulating transition stats!") except Exception as e: if isinstance(e, KaldiProcessingError): log_kaldi_errors(e.error_logs) e.update_log_file() raise
def finalize_training(self): self.compute_phone_pdf_counts() self.collect_alignments() self.train_phone_lm()
[docs] def export_files( self, output_directory: Path, output_format: Optional[str] = None, include_original_text: bool = False, ) -> None: """ Export a TextGrid file for every sound file in the dataset Parameters ---------- output_directory: :class:`~pathlib.Path` Directory to save to output_format: str, optional Format to save alignments, one of 'long_textgrids' (the default), 'short_textgrids', or 'json', passed to praatio include_original_text: bool Flag for including the original text of the corpus files as a tier """ self.align() self.analyze_alignments() super(TrainableAligner, self).export_files( output_directory, output_format, include_original_text )
@property def num_current_utterances(self) -> int: """Number of utterances in the current subset""" if self.current_subset and self.current_subset < self.num_utterances: return self.current_subset return self.num_utterances @property def align_options(self) -> MetaDict: """Alignment options""" if self.current_aligner is not None: return self.current_aligner.align_options return super().align_options
[docs] def align(self) -> None: """ Multiprocessing function that aligns based on the current model. See Also -------- :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction` Multiprocessing helper function for each job :meth:`.AlignMixin.align_arguments` Job method for generating arguments for the helper function :kaldi_steps:`align_si` Reference Kaldi script :kaldi_steps:`align_fmllr` Reference Kaldi script """ wf = self.current_workflow if wf.done: logger.debug(f"Skipping {self.current_aligner.identifier} alignments") return try: self.current_acoustic_model.export_model(self.working_directory) self.uses_speaker_adaptation = False self.compile_train_graphs() self.align_utterances() if self.current_acoustic_model.meta["features"]["uses_speaker_adaptation"]: arguments = self.calc_fmllr_arguments() missing_transforms = False for arg in arguments: for path in arg.trans_paths.values(): if not os.path.exists(path): missing_transforms = True if missing_transforms: assert self.alignment_model_path.suffix == ".alimdl" self.calc_fmllr() self.uses_speaker_adaptation = True assert self.alignment_model_path.suffix == ".mdl" self.align_utterances() if self.current_subset: logger.debug( f"Analyzing alignment diagnostics for {self.current_aligner.identifier} on {self.current_subset} utterances" ) else: logger.debug( f"Analyzing alignment diagnostics for {self.current_aligner.identifier} on the full corpus" ) self.compile_information() with self.session() as session: session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update( {"done": True} ) session.commit() except Exception as e: with self.session() as session: session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update( {"dirty": True} ) session.commit() if isinstance(e, KaldiProcessingError): log_kaldi_errors(e.error_logs) e.update_log_file() raise
@property def alignment_model_path(self) -> Path: """Current alignment model path""" path = self.working_directory.joinpath("final.alimdl") if os.path.exists(path) and not self.uses_speaker_adaptation: return path return self.model_path @property def model_path(self) -> Path: """Current model path""" if self.current_trainer is not None: return self.current_trainer.model_path return self.working_directory.joinpath("final.mdl") @property def data_directory(self) -> str: """Current data directory based on the trainer's subset""" return self.subset_directory(self.current_subset) @property def working_directory(self) -> Optional[Path]: """Working directory""" if self.current_trainer is not None and not self.current_trainer.training_complete: return self.current_trainer.working_directory if self.current_aligner is None: return None return self.output_directory.joinpath(f"{self.current_aligner.identifier}_ali") @property def working_log_directory(self) -> Optional[str]: """Current log directory""" return self.working_directory.joinpath("log")