Source code for montreal_forced_aligner.acoustic_modeling.lda

"""Class definitions for LDA trainer"""
from __future__ import annotations

import logging
import multiprocessing as mp
import os
import re
import shutil
import subprocess
import typing
from pathlib import Path
from queue import Empty
from typing import TYPE_CHECKING, Dict, List

from tqdm.rich import tqdm

from montreal_forced_aligner.abc import KaldiFunction
from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import MfaArguments
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import (
    KaldiProcessWorker,
    Stopped,
    parse_logs,
    thirdparty_binary,
)

if TYPE_CHECKING:
    from montreal_forced_aligner.abc import MetaDict


__all__ = [
    "LdaTrainer",
    "CalcLdaMlltFunction",
    "CalcLdaMlltArguments",
    "LdaAccStatsFunction",
    "LdaAccStatsArguments",
]

logger = logging.getLogger("mfa")


class LdaAccStatsArguments(MfaArguments):
    """Arguments for :class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`"""

    # Dictionary identifiers the job iterates over; used as keys into the mappings below
    dictionaries: List[str]
    # Per-dictionary feature specifier handed to ``acc-lda``
    feature_strings: Dict[str, str]
    # Per-dictionary alignment archive read by ``ali-to-post``
    ali_paths: Dict[str, Path]
    # Model file passed to ``weight-silence-post`` and ``acc-lda``
    model_path: Path
    # LDA options; the function reads the ``silence_csl`` and ``random_prune`` keys
    lda_options: MetaDict
    # Per-dictionary output path for the accumulated LDA statistics
    acc_paths: Dict[str, Path]
class CalcLdaMlltArguments(MfaArguments):
    """Arguments for :class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`"""

    # Dictionary identifiers the job iterates over; used as keys into the mappings below
    dictionaries: List[str]
    # Per-dictionary feature specifier handed to ``gmm-acc-mllt``
    feature_strings: Dict[str, str]
    # Per-dictionary alignment archive read by ``ali-to-post``
    ali_paths: Dict[str, Path]
    # Model file passed to ``weight-silence-post`` and ``gmm-acc-mllt``
    model_path: Path
    # LDA options; the function reads the ``silence_csl`` and ``random_prune`` keys
    lda_options: MetaDict
    # Per-dictionary output path for the accumulated MLLT statistics
    macc_paths: Dict[str, Path]
class LdaAccStatsFunction(KaldiFunction):
    """
    Multiprocessing function to accumulate LDA stats

    See Also
    --------
    :meth:`.LdaTrainer.lda_acc_stats`
        Main function that calls this function in parallel
    :meth:`.LdaTrainer.lda_acc_stats_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`ali-to-post`
        Relevant Kaldi binary
    :kaldi_src:`weight-silence-post`
        Relevant Kaldi binary
    :kaldi_src:`acc-lda`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsArguments`
        Arguments for the function
    """

    # Matches the ``acc-lda`` log line that reports processed and failed file counts
    progress_pattern = re.compile(r"^LOG.*Done (?P<done>\d+) files, failed for (?P<failed>\d+)$")

    def __init__(self, args: LdaAccStatsArguments):
        super().__init__(args)
        self.dictionaries = args.dictionaries
        self.feature_strings = args.feature_strings
        self.ali_paths = args.ali_paths
        self.model_path = args.model_path
        self.acc_paths = args.acc_paths
        self.lda_options = args.lda_options

    def _run(self) -> typing.Generator[typing.Tuple[int, int], None, None]:
        """
        Run the ``ali-to-post | weight-silence-post | acc-lda`` pipeline once per dictionary

        Yields
        ------
        tuple[int, int]
            Number of files done and number of files failed, parsed from each
            ``acc-lda`` log line that matches :attr:`progress_pattern`
        """
        with mfa_open(self.log_path, "w") as log_file:
            for dict_id in self.dictionaries:
                ali_path = self.ali_paths[dict_id]
                feature_string = self.feature_strings[dict_id]
                acc_path = self.acc_paths[dict_id]
                ali_to_post_proc = subprocess.Popen(
                    [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"],
                    stderr=log_file,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                )
                weight_silence_post_proc = subprocess.Popen(
                    [
                        thirdparty_binary("weight-silence-post"),
                        "0.0",
                        self.lda_options["silence_csl"],
                        self.model_path,
                        "ark:-",
                        "ark:-",
                    ],
                    stdin=ali_to_post_proc.stdout,
                    stderr=log_file,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                )
                # The child owns its own copy of the pipe now; drop ours so the
                # upstream process gets SIGPIPE if the downstream one exits early.
                ali_to_post_proc.stdout.close()
                acc_lda_post_proc = subprocess.Popen(
                    [
                        thirdparty_binary("acc-lda"),
                        f"--rand-prune={self.lda_options['random_prune']}",
                        self.model_path,
                        feature_string,
                        "ark,s,cs:-",
                        acc_path,
                    ],
                    stdin=weight_silence_post_proc.stdout,
                    stderr=subprocess.PIPE,
                    encoding="utf8",
                    env=os.environ,
                )
                weight_silence_post_proc.stdout.close()
                # stderr is both mirrored into the job log and scanned for progress
                for line in acc_lda_post_proc.stderr:
                    log_file.write(line)
                    m = self.progress_pattern.match(line.strip())
                    if m:
                        yield int(m.group("done")), int(m.group("failed"))
                self.check_call(acc_lda_post_proc)
class CalcLdaMlltFunction(KaldiFunction):
    """
    Multiprocessing function for estimating LDA with MLLT.

    See Also
    --------
    :meth:`.LdaTrainer.calc_lda_mllt`
        Main function that calls this function in parallel
    :meth:`.LdaTrainer.calc_lda_mllt_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`ali-to-post`
        Relevant Kaldi binary
    :kaldi_src:`weight-silence-post`
        Relevant Kaldi binary
    :kaldi_src:`gmm-acc-mllt`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltArguments`
        Arguments for the function
    """

    # Matches the per-file likelihood line that ``gmm-acc-mllt`` writes to its log
    progress_pattern = re.compile(r"^LOG.*Average like for this file.*$")

    def __init__(self, args: CalcLdaMlltArguments):
        super().__init__(args)
        self.dictionaries = args.dictionaries
        self.feature_strings = args.feature_strings
        self.ali_paths = args.ali_paths
        self.model_path = args.model_path
        self.macc_paths = args.macc_paths
        self.lda_options = args.lda_options

    def _run(self) -> typing.Generator[int, None, None]:
        """
        Run the ``ali-to-post | weight-silence-post | gmm-acc-mllt`` pipeline once per dictionary

        Yields
        ------
        int
            ``1`` for every ``gmm-acc-mllt`` log line that matches :attr:`progress_pattern`
        """
        # Estimating MLLT
        with mfa_open(self.log_path, "w") as log_file:
            for dict_id in self.dictionaries:
                ali_path = self.ali_paths[dict_id]
                feature_string = self.feature_strings[dict_id]
                macc_path = self.macc_paths[dict_id]
                post_proc = subprocess.Popen(
                    [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"],
                    stdout=subprocess.PIPE,
                    stderr=log_file,
                    env=os.environ,
                )
                weight_proc = subprocess.Popen(
                    [
                        thirdparty_binary("weight-silence-post"),
                        "0.0",
                        self.lda_options["silence_csl"],
                        self.model_path,
                        "ark:-",
                        "ark:-",
                    ],
                    stdin=post_proc.stdout,
                    stdout=subprocess.PIPE,
                    stderr=log_file,
                    env=os.environ,
                )
                # The child owns its own copy of the pipe now; drop ours so the
                # upstream process gets SIGPIPE if the downstream one exits early.
                post_proc.stdout.close()
                acc_proc = subprocess.Popen(
                    [
                        thirdparty_binary("gmm-acc-mllt"),
                        f"--rand-prune={self.lda_options['random_prune']}",
                        self.model_path,
                        feature_string,
                        "ark,s,cs:-",
                        macc_path,
                    ],
                    stdin=weight_proc.stdout,
                    stderr=subprocess.PIPE,
                    encoding="utf8",
                    env=os.environ,
                )
                weight_proc.stdout.close()
                # stderr is both mirrored into the job log and scanned for progress
                for line in acc_proc.stderr:
                    log_file.write(line)
                    m = self.progress_pattern.match(line.strip())
                    if m:
                        yield 1
                self.check_call(acc_proc)
class LdaTrainer(TriphoneTrainer):
    """
    Triphone trainer that additionally estimates LDA and MLLT feature transforms

    Parameters
    ----------
    subset : int
        Number of utterances to use, defaults to 10000
    num_leaves : int
        Number of states in the decision tree, defaults to 2500
    max_gaussians : int
        Number of gaussians in the decision tree, defaults to 15000
    lda_dimension : int
        Dimensionality of the LDA matrix
    uses_splices : bool
        Flag to use spliced and LDA calculation
    splice_left_context : int or None
        Number of frames to splice on the left for calculating LDA
    splice_right_context : int or None
        Number of frames to splice on the right for calculating LDA
    random_prune : float
        This is approximately the ratio by which we will speed up the
        LDA and MLLT calculations via randomized pruning
    boost_silence : float
        Forwarded unchanged to :class:`~montreal_forced_aligner.acoustic_modeling.triphone.TriphoneTrainer`
    power : float
        Forwarded unchanged to :class:`~montreal_forced_aligner.acoustic_modeling.triphone.TriphoneTrainer`

    See Also
    --------
    :class:`~montreal_forced_aligner.acoustic_modeling.triphone.TriphoneTrainer`
        For acoustic model training parsing parameters

    Attributes
    ----------
    mllt_iterations : list
        List of iterations to perform MLLT estimation
    """

    def __init__(
        self,
        subset: int = 10000,
        num_leaves: int = 2500,
        max_gaussians: int = 15000,
        lda_dimension: int = 40,
        uses_splices: bool = True,
        splice_left_context: int = 3,
        splice_right_context: int = 3,
        random_prune: float = 4.0,
        boost_silence: float = 1.0,
        power: float = 0.25,
        **kwargs,
    ):
        super().__init__(
            boost_silence=boost_silence,
            power=power,
            subset=subset,
            num_leaves=num_leaves,
            max_gaussians=max_gaussians,
            **kwargs,
        )
        self.lda_dimension = lda_dimension
        self.random_prune = random_prune
        self.uses_splices = uses_splices
        self.splice_left_context = splice_left_context
        self.splice_right_context = splice_right_context

    def lda_acc_stats_arguments(self) -> List[LdaAccStatsArguments]:
        """
        Generate Job arguments for :class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsArguments`]
            Arguments for processing
        """
        arguments = []
        for j in self.jobs:
            feat_strings = {}
            for d_id in j.dictionary_ids:
                feat_strings[d_id] = j.construct_feature_proc_string(
                    self.working_directory,
                    d_id,
                    self.feature_options["uses_splices"],
                    self.feature_options["splice_left_context"],
                    self.feature_options["splice_right_context"],
                    self.feature_options["uses_speaker_adaptation"],
                )
            arguments.append(
                LdaAccStatsArguments(
                    j.id,
                    getattr(self, "db_string", ""),
                    self.working_log_directory.joinpath(f"lda_acc_stats.{j.id}.log"),
                    j.dictionary_ids,
                    feat_strings,
                    # Statistics are accumulated over the previous aligner's alignments and model
                    j.construct_path_dictionary(
                        self.previous_aligner.working_directory, "ali", "ark"
                    ),
                    self.previous_aligner.alignment_model_path,
                    self.lda_options,
                    j.construct_path_dictionary(self.working_directory, "lda", "acc"),
                )
            )
        return arguments

    def calc_lda_mllt_arguments(self) -> List[CalcLdaMlltArguments]:
        """
        Generate Job arguments for :class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltArguments`]
            Arguments for processing
        """
        arguments = []
        for j in self.jobs:
            feat_strings = {}
            for d_id in j.dictionary_ids:
                feat_strings[d_id] = j.construct_feature_proc_string(
                    self.working_directory,
                    d_id,
                    self.feature_options["uses_splices"],
                    self.feature_options["splice_left_context"],
                    self.feature_options["splice_right_context"],
                    self.feature_options["uses_speaker_adaptation"],
                )
            arguments.append(
                CalcLdaMlltArguments(
                    j.id,
                    getattr(self, "db_string", ""),
                    # Built with joinpath so the log path is a Path, matching
                    # lda_acc_stats_arguments
                    self.working_log_directory.joinpath(
                        f"lda_mllt.{self.iteration}.{j.id}.log"
                    ),
                    j.dictionary_ids,
                    feat_strings,
                    j.construct_path_dictionary(self.working_directory, "ali", "ark"),
                    self.model_path,
                    self.lda_options,
                    j.construct_path_dictionary(self.working_directory, "lda", "macc"),
                )
            )
        return arguments

    @property
    def train_type(self) -> str:
        """Training identifier"""
        return "lda"

    @property
    def lda_options(self) -> MetaDict:
        """Options for computing LDA"""
        return {
            "lda_dimension": self.lda_dimension,
            "random_prune": self.random_prune,
            "silence_csl": self.silence_csl,
            "splice_left_context": self.splice_left_context,
            "splice_right_context": self.splice_right_context,
        }

    def compute_calculated_properties(self) -> None:
        """Run the parent's calculated-property setup, then set the fixed MLLT estimation iterations"""
        super().compute_calculated_properties()
        self.mllt_iterations = [2, 4, 6, 12]

    def lda_acc_stats(self) -> None:
        """
        Multiprocessing function that accumulates LDA statistics.

        After all jobs finish, ``est-lda`` is run over the accumulated statistics to
        write ``lda.mat`` in the working directory, which is then copied into the
        worker's working directory.

        See Also
        --------
        :class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.lda_acc_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-lda`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script
        """
        worker_lda_path = os.path.join(self.worker.working_directory, "lda.mat")
        lda_path = self.working_directory.joinpath("lda.mat")
        # Remove any existing worker copy before a new matrix is estimated
        if os.path.exists(worker_lda_path):
            os.remove(worker_lda_path)
        arguments = self.lda_acc_stats_arguments()
        with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            if GLOBAL_CONFIG.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = LdaAccStatsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            # Collect worker errors; they are raised after all workers are joined
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Queue is idle: leave the loop only once every worker reports finished
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    done, errors = result
                    pbar.update(done + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = LdaAccStatsFunction(args)
                    for done, errors in function.run():
                        pbar.update(done + errors)

        log_path = self.working_log_directory.joinpath("lda_est.log")
        acc_list = [path for x in arguments for path in x.acc_paths.values()]
        with mfa_open(log_path, "w") as log_file:
            est_lda_proc = subprocess.Popen(
                [
                    thirdparty_binary("est-lda"),
                    f"--dim={self.lda_dimension}",
                    lda_path,
                ]
                + acc_list,
                stderr=log_file,
                env=os.environ,
            )
            est_lda_proc.communicate()
        if est_lda_proc.returncode != 0:
            # Report the failure instead of silently ignoring the exit status;
            # deliberately not raising so existing control flow is unchanged.
            logger.warning(
                "est-lda exited with code %s, see %s for details",
                est_lda_proc.returncode,
                log_path,
            )
        shutil.copyfile(
            lda_path,
            worker_lda_path,
        )

    def _trainer_initialization(self) -> None:
        """Initialize LDA training"""
        self.uses_splices = True
        self.worker.uses_splices = True
        if self.initialized:
            return
        self.lda_acc_stats()
        self.tree_stats()
        self._setup_tree(initial_mix_up=False)

        self.compile_train_graphs()

        self.convert_alignments()
        os.rename(self.model_path, self.next_model_path)

    def calc_lda_mllt(self) -> None:
        """
        Multiprocessing function that calculates LDA+MLLT transformations.

        After the per-job statistics are accumulated, ``est-mllt`` writes a new
        matrix, ``gmm-transform-means`` applies it to the current model in place, and
        the new matrix is composed with any existing ``lda.mat`` to replace it.

        See Also
        --------
        :class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.calc_lda_mllt_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-mllt`
            Relevant Kaldi binary
        :kaldi_src:`gmm-transform-means`
            Relevant Kaldi binary
        :kaldi_src:`compose-transforms`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script
        """
        logger.info("Re-calculating LDA...")
        arguments = self.calc_lda_mllt_arguments()
        with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            if GLOBAL_CONFIG.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = CalcLdaMlltFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            # Collect worker errors; they are raised after all workers are joined
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Queue is idle: leave the loop only once every worker reports finished
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    pbar.update(1)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = CalcLdaMlltFunction(args)
                    for _ in function.run():
                        pbar.update(1)

        log_path = self.working_log_directory.joinpath(f"transform_means.{self.iteration}.log")
        previous_mat_path = self.working_directory.joinpath("lda.mat")
        new_mat_path = self.working_directory.joinpath("lda_new.mat")
        composed_path = self.working_directory.joinpath("lda_composed.mat")
        with mfa_open(log_path, "a") as log_file:
            macc_list = [path for x in arguments for path in x.macc_paths.values()]
            subprocess.call(
                [thirdparty_binary("est-mllt"), new_mat_path] + macc_list,
                stderr=log_file,
                env=os.environ,
            )
            # Input and output model are the same path, so the model is rewritten in place
            subprocess.call(
                [
                    thirdparty_binary("gmm-transform-means"),
                    new_mat_path,
                    self.model_path,
                    self.model_path,
                ],
                stderr=log_file,
                env=os.environ,
            )

            if previous_mat_path.exists():
                subprocess.call(
                    [
                        thirdparty_binary("compose-transforms"),
                        new_mat_path,
                        previous_mat_path,
                        composed_path,
                    ],
                    stderr=log_file,
                    env=os.environ,
                )
                # os.replace overwrites the destination in one step, so the previous
                # matrix is kept if the composed file was never written.
                os.replace(composed_path, previous_mat_path)
            else:
                os.rename(new_mat_path, previous_mat_path)

    def train_iteration(self) -> None:
        """
        Run a single LDA training iteration
        """
        # If the next model already exists, only advance the bookkeeping for this iteration
        if os.path.exists(self.next_model_path):
            if self.iteration <= self.final_gaussian_iteration:
                self.increment_gaussians()
            self.iteration += 1
            return
        if self.iteration in self.realignment_iterations:
            self.align_iteration()
        if self.iteration in self.mllt_iterations:
            self.calc_lda_mllt()

        self.acc_stats()
        parse_logs(self.working_log_directory)
        if self.iteration <= self.final_gaussian_iteration:
            self.increment_gaussians()
        self.iteration += 1