# Source code for montreal_forced_aligner.acoustic_modeling.lda

"""Class definitions for LDA trainer"""
from __future__ import annotations

import logging
import multiprocessing as mp
import os
import re
import shutil
import subprocess
import typing
from queue import Empty
from typing import TYPE_CHECKING, Dict, List

import tqdm

from montreal_forced_aligner.abc import KaldiFunction
from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import MfaArguments
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import (
    KaldiProcessWorker,
    Stopped,
    parse_logs,
    thirdparty_binary,
)

if TYPE_CHECKING:
    from montreal_forced_aligner.abc import MetaDict


# Names exported by ``from montreal_forced_aligner.acoustic_modeling.lda import *``
__all__ = [
    "LdaTrainer",
    "CalcLdaMlltFunction",
    "CalcLdaMlltArguments",
    "LdaAccStatsFunction",
    "LdaAccStatsArguments",
]

# Module-level logger bound to the "mfa" logger name
logger = logging.getLogger("mfa")


class LdaAccStatsArguments(MfaArguments):
    """
    Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`

    The fields below are filled positionally by
    :meth:`.LdaTrainer.lda_acc_stats_arguments`, so their declaration order
    must not change.  They follow whatever fields ``MfaArguments`` itself
    declares (presumably job id, database string and log path, judging from
    the call site -- confirm against ``MfaArguments``).
    """

    # Dictionary identifiers processed by the job; used as keys into the mappings below
    dictionaries: List[str]
    # Per-dictionary feature specifier handed to ``acc-lda``
    feature_strings: Dict[str, str]
    # Per-dictionary alignment archive read by ``ali-to-post`` (as ``ark:<path>``)
    ali_paths: Dict[str, str]
    # Model file passed to ``weight-silence-post`` and ``acc-lda``
    model_path: str
    # Option mapping; the function reads the "silence_csl" and "random_prune" keys
    lda_options: MetaDict
    # Per-dictionary path given to ``acc-lda`` as its final argument
    acc_paths: Dict[str, str]
class CalcLdaMlltArguments(MfaArguments):
    """
    Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`

    The fields below are filled positionally by
    :meth:`.LdaTrainer.calc_lda_mllt_arguments`, so their declaration order
    must not change.  They follow whatever fields ``MfaArguments`` itself
    declares (presumably job id, database string and log path, judging from
    the call site -- confirm against ``MfaArguments``).
    """

    # Dictionary identifiers processed by the job; used as keys into the mappings below
    dictionaries: List[str]
    # Per-dictionary feature specifier handed to ``gmm-acc-mllt``
    feature_strings: Dict[str, str]
    # Per-dictionary alignment archive read by ``ali-to-post`` (as ``ark:<path>``)
    ali_paths: Dict[str, str]
    # Model file passed to ``weight-silence-post`` and ``gmm-acc-mllt``
    model_path: str
    # Option mapping; the function reads the "silence_csl" and "random_prune" keys
    lda_options: MetaDict
    # Per-dictionary path given to ``gmm-acc-mllt`` as its final argument
    macc_paths: Dict[str, str]
class LdaAccStatsFunction(KaldiFunction):
    """
    Multiprocessing function to accumulate LDA stats

    For every dictionary of the job, runs the pipeline
    ``ali-to-post | weight-silence-post | acc-lda`` and reports progress
    parsed from the ``acc-lda`` log output.

    See Also
    --------
    :meth:`.LdaTrainer.lda_acc_stats`
        Main function that calls this function in parallel
    :meth:`.LdaTrainer.lda_acc_stats_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`ali-to-post`
        Relevant Kaldi binary
    :kaldi_src:`weight-silence-post`
        Relevant Kaldi binary
    :kaldi_src:`acc-lda`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsArguments`
        Arguments for the function
    """

    # Matches the "Done N files, failed for M" summary line in the acc-lda log
    progress_pattern = re.compile(r"^LOG.*Done (?P<done>\d+) files, failed for (?P<failed>\d+)$")

    def __init__(self, args: LdaAccStatsArguments):
        super().__init__(args)
        self.dictionaries = args.dictionaries
        self.feature_strings = args.feature_strings
        self.ali_paths = args.ali_paths
        self.model_path = args.model_path
        self.acc_paths = args.acc_paths
        self.lda_options = args.lda_options

    def _run(self) -> typing.Generator[typing.Tuple[int, int], None, None]:
        """
        Run the function

        Yields
        ------
        tuple[int, int]
            ``(done, failed)`` counts parsed from each matching ``acc-lda`` log line
        """
        with mfa_open(self.log_path, "w") as log_file:
            for dict_id in self.dictionaries:
                ali_path = self.ali_paths[dict_id]
                feature_string = self.feature_strings[dict_id]
                acc_path = self.acc_paths[dict_id]
                ali_to_post_proc = subprocess.Popen(
                    [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"],
                    stderr=log_file,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                )
                weight_silence_post_proc = subprocess.Popen(
                    [
                        thirdparty_binary("weight-silence-post"),
                        "0.0",
                        self.lda_options["silence_csl"],
                        self.model_path,
                        "ark:-",
                        "ark:-",
                    ],
                    stdin=ali_to_post_proc.stdout,
                    stderr=log_file,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                )
                # The child owns its own copy of the pipe now; closing ours lets
                # ali-to-post receive SIGPIPE if the downstream process exits early.
                ali_to_post_proc.stdout.close()
                acc_lda_post_proc = subprocess.Popen(
                    [
                        thirdparty_binary("acc-lda"),
                        f"--rand-prune={self.lda_options['random_prune']}",
                        self.model_path,
                        feature_string,
                        "ark,s,cs:-",
                        acc_path,
                    ],
                    stdin=weight_silence_post_proc.stdout,
                    stderr=subprocess.PIPE,
                    encoding="utf8",
                    env=os.environ,
                )
                # Same reasoning as above, for weight-silence-post.
                weight_silence_post_proc.stdout.close()
                # acc-lda's stderr is both mirrored to the log and scanned for progress
                for line in acc_lda_post_proc.stderr:
                    log_file.write(line)
                    m = self.progress_pattern.match(line.strip())
                    if m:
                        yield int(m.group("done")), int(m.group("failed"))
                self.check_call(acc_lda_post_proc)
                # acc-lda has finished, so the upstream stages have reached the
                # end of their output; reap them instead of leaving zombies.
                ali_to_post_proc.wait()
                weight_silence_post_proc.wait()
class CalcLdaMlltFunction(KaldiFunction):
    """
    Multiprocessing function for estimating LDA with MLLT.

    For every dictionary of the job, runs the pipeline
    ``ali-to-post | weight-silence-post | gmm-acc-mllt`` and reports one unit
    of progress per matching ``gmm-acc-mllt`` log line.

    See Also
    --------
    :meth:`.LdaTrainer.calc_lda_mllt`
        Main function that calls this function in parallel
    :meth:`.LdaTrainer.calc_lda_mllt_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`ali-to-post`
        Relevant Kaldi binary
    :kaldi_src:`weight-silence-post`
        Relevant Kaldi binary
    :kaldi_src:`gmm-acc-mllt`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltArguments`
        Arguments for the function
    """

    # Matches the per-file "Average like for this file" line in the gmm-acc-mllt log
    progress_pattern = re.compile(r"^LOG.*Average like for this file.*$")

    def __init__(self, args: CalcLdaMlltArguments):
        super().__init__(args)
        self.dictionaries = args.dictionaries
        self.feature_strings = args.feature_strings
        self.ali_paths = args.ali_paths
        self.model_path = args.model_path
        self.macc_paths = args.macc_paths
        self.lda_options = args.lda_options

    def _run(self) -> typing.Generator[int, None, None]:
        """
        Run the function

        Yields
        ------
        int
            ``1`` for each ``gmm-acc-mllt`` log line matching :attr:`progress_pattern`
        """
        # Estimating MLLT
        with mfa_open(self.log_path, "w") as log_file:
            for dict_id in self.dictionaries:
                ali_path = self.ali_paths[dict_id]
                feature_string = self.feature_strings[dict_id]
                macc_path = self.macc_paths[dict_id]
                post_proc = subprocess.Popen(
                    [thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"],
                    stdout=subprocess.PIPE,
                    stderr=log_file,
                    env=os.environ,
                )
                weight_proc = subprocess.Popen(
                    [
                        thirdparty_binary("weight-silence-post"),
                        "0.0",
                        self.lda_options["silence_csl"],
                        self.model_path,
                        "ark:-",
                        "ark:-",
                    ],
                    stdin=post_proc.stdout,
                    stdout=subprocess.PIPE,
                    stderr=log_file,
                    env=os.environ,
                )
                # The child owns its own copy of the pipe now; closing ours lets
                # ali-to-post receive SIGPIPE if the downstream process exits early.
                post_proc.stdout.close()
                acc_proc = subprocess.Popen(
                    [
                        thirdparty_binary("gmm-acc-mllt"),
                        f"--rand-prune={self.lda_options['random_prune']}",
                        self.model_path,
                        feature_string,
                        "ark,s,cs:-",
                        macc_path,
                    ],
                    stdin=weight_proc.stdout,
                    stderr=subprocess.PIPE,
                    encoding="utf8",
                    env=os.environ,
                )
                # Same reasoning as above, for weight-silence-post.
                weight_proc.stdout.close()
                # gmm-acc-mllt's stderr is both mirrored to the log and scanned for progress
                for line in acc_proc.stderr:
                    log_file.write(line)
                    m = self.progress_pattern.match(line.strip())
                    if m:
                        yield 1
                self.check_call(acc_proc)
                # gmm-acc-mllt has finished, so the upstream stages have reached
                # the end of their output; reap them instead of leaving zombies.
                post_proc.wait()
                weight_proc.wait()
class LdaTrainer(TriphoneTrainer):
    """
    LDA+MLLT trainer, built on top of the triphone trainer

    Parameters
    ----------
    subset : int
        Number of utterances to use, defaults to 10000
    num_leaves : int
        Number of states in the decision tree, defaults to 2500
    max_gaussians : int
        Number of gaussians in the decision tree, defaults to 15000
    lda_dimension : int
        Dimensionality of the LDA matrix
    uses_splices : bool
        Flag to use spliced and LDA calculation
    splice_left_context : int or None
        Number of frames to splice on the left for calculating LDA
    splice_right_context : int or None
        Number of frames to splice on the right for calculating LDA
    random_prune : float
        This is approximately the ratio by which we will speed up the
        LDA and MLLT calculations via randomized pruning
    boost_silence : float
        Forwarded unchanged to :class:`~montreal_forced_aligner.acoustic_modeling.triphone.TriphoneTrainer`
    power : float
        Forwarded unchanged to :class:`~montreal_forced_aligner.acoustic_modeling.triphone.TriphoneTrainer`

    See Also
    --------
    :class:`~montreal_forced_aligner.acoustic_modeling.triphone.TriphoneTrainer`
        For acoustic model training parsing parameters

    Attributes
    ----------
    mllt_iterations : list
        List of iterations to perform MLLT estimation
    """

    def __init__(
        self,
        subset: int = 10000,
        num_leaves: int = 2500,
        max_gaussians: int = 15000,
        lda_dimension: int = 40,
        uses_splices: bool = True,
        splice_left_context: int = 3,
        splice_right_context: int = 3,
        random_prune: float = 4.0,
        boost_silence: float = 1.0,
        power: float = 0.25,
        **kwargs,
    ):
        super().__init__(
            boost_silence=boost_silence,
            power=power,
            subset=subset,
            num_leaves=num_leaves,
            max_gaussians=max_gaussians,
            **kwargs,
        )
        self.lda_dimension = lda_dimension
        self.random_prune = random_prune
        self.uses_splices = uses_splices
        self.splice_left_context = splice_left_context
        self.splice_right_context = splice_right_context

    def _lda_feature_strings(self, job) -> Dict[str, str]:
        """Build the feature specifier of every dictionary handled by ``job``"""
        return {
            d_id: job.construct_feature_proc_string(
                self.working_directory,
                d_id,
                self.feature_options["uses_splices"],
                self.feature_options["splice_left_context"],
                self.feature_options["splice_right_context"],
                self.feature_options["uses_speaker_adaptation"],
            )
            for d_id in job.dictionary_ids
        }

    def _iter_lda_job_results(
        self, function_class: typing.Type[KaldiFunction], arguments: List[MfaArguments]
    ) -> typing.Iterator[typing.Any]:
        """
        Run ``function_class`` once per argument set and yield every progress result

        Uses one worker process per argument set when ``GLOBAL_CONFIG.use_mp`` is
        set and runs the functions sequentially in-process otherwise.  In the
        multiprocessing case, exceptions received from workers are collected and
        the first one is raised after all workers have been joined.
        """
        if not GLOBAL_CONFIG.use_mp:
            for args in arguments:
                yield from function_class(args).run()
            return
        error_dict = {}
        return_queue = mp.Queue()
        stopped = Stopped()
        procs = []
        for i, args in enumerate(arguments):
            p = KaldiProcessWorker(i, return_queue, function_class(args), stopped)
            procs.append(p)
            p.start()
        while True:
            try:
                result = return_queue.get(timeout=1)
            except Empty:
                # Queue idle: stop only once every worker reports it has finished
                if all(proc.finished.stop_check() for proc in procs):
                    break
                continue
            if isinstance(result, Exception):
                error_dict[getattr(result, "job_name", 0)] = result
                continue
            if stopped.stop_check():
                # Keep draining the queue, but discard results once a stop was requested
                continue
            yield result
        for p in procs:
            p.join()
        for v in error_dict.values():
            raise v

    @staticmethod
    def _call_kaldi_binary(command: List[str], log_file: typing.TextIO, log_path: str) -> int:
        """Run ``command`` with stderr sent to ``log_file``; log (but do not raise on) failure"""
        return_code = subprocess.call(command, stderr=log_file, env=os.environ)
        if return_code != 0:
            logger.error(
                "%s exited with code %s, see %s for details",
                os.path.basename(command[0]),
                return_code,
                log_path,
            )
        return return_code

    def lda_acc_stats_arguments(self) -> List[LdaAccStatsArguments]:
        """
        Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`

        Alignments and the model are taken from the previous aligner; the
        accumulated statistics are written to this trainer's working directory.

        Returns
        -------
        list[:class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsArguments`]
            Arguments for processing
        """
        return [
            LdaAccStatsArguments(
                j.id,
                getattr(self, "db_string", ""),
                os.path.join(self.working_log_directory, f"lda_acc_stats.{j.id}.log"),
                j.dictionary_ids,
                self._lda_feature_strings(j),
                j.construct_path_dictionary(
                    self.previous_aligner.working_directory, "ali", "ark"
                ),
                self.previous_aligner.alignment_model_path,
                self.lda_options,
                j.construct_path_dictionary(self.working_directory, "lda", "acc"),
            )
            for j in self.jobs
        ]

    def calc_lda_mllt_arguments(self) -> List[CalcLdaMlltArguments]:
        """
        Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`

        Alignments and the model are taken from this trainer's own working
        directory; the log file name includes the current iteration.

        Returns
        -------
        list[:class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltArguments`]
            Arguments for processing
        """
        return [
            CalcLdaMlltArguments(
                j.id,
                getattr(self, "db_string", ""),
                os.path.join(
                    self.working_log_directory, f"lda_mllt.{self.iteration}.{j.id}.log"
                ),
                j.dictionary_ids,
                self._lda_feature_strings(j),
                j.construct_path_dictionary(self.working_directory, "ali", "ark"),
                self.model_path,
                self.lda_options,
                j.construct_path_dictionary(self.working_directory, "lda", "macc"),
            )
            for j in self.jobs
        ]

    @property
    def train_type(self) -> str:
        """Training identifier"""
        return "lda"

    @property
    def lda_options(self) -> MetaDict:
        """Options for computing LDA"""
        return {
            "lda_dimension": self.lda_dimension,
            "random_prune": self.random_prune,
            "silence_csl": self.silence_csl,
            "splice_left_context": self.splice_left_context,
            "splice_right_context": self.splice_right_context,
        }

    def compute_calculated_properties(self) -> None:
        """Generate realignment iterations, MLLT estimation iterations, and initial gaussians based on configuration"""
        super().compute_calculated_properties()
        self.mllt_iterations = [2, 4, 6, 12]

    def lda_acc_stats(self) -> None:
        """
        Multiprocessing function that accumulates LDA statistics.

        After the per-job statistics are accumulated, ``est-lda`` is run over all
        of them to write ``lda.mat`` into the working directory, which is then
        copied into the worker's working directory.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.lda_acc_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-lda`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script
        """
        worker_lda_path = os.path.join(self.worker.working_directory, "lda.mat")
        lda_path = os.path.join(self.working_directory, "lda.mat")
        # Drop any transform left in the worker directory before recomputing it
        if os.path.exists(worker_lda_path):
            os.remove(worker_lda_path)
        arguments = self.lda_acc_stats_arguments()
        with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            for done, errors in self._iter_lda_job_results(LdaAccStatsFunction, arguments):
                pbar.update(done + errors)
        log_path = os.path.join(self.working_log_directory, "lda_est.log")
        acc_list = [path for x in arguments for path in x.acc_paths.values()]
        with mfa_open(log_path, "w") as log_file:
            self._call_kaldi_binary(
                [
                    thirdparty_binary("est-lda"),
                    f"--dim={self.lda_dimension}",
                    lda_path,
                ]
                + acc_list,
                log_file,
                log_path,
            )
        shutil.copyfile(
            lda_path,
            worker_lda_path,
        )

    def _trainer_initialization(self) -> None:
        """Initialize LDA training"""
        # Splicing is always enabled here, regardless of the constructor argument
        self.uses_splices = True
        self.worker.uses_splices = True
        if self.initialized:
            return
        self.lda_acc_stats()
        self.tree_stats()
        self._setup_tree(initial_mix_up=False)

        self.compile_train_graphs()

        self.convert_alignments()
        os.rename(self.model_path, self.next_model_path)

    def calc_lda_mllt(self) -> None:
        """
        Multiprocessing function that calculates LDA+MLLT transformations.

        After the per-job statistics are accumulated, ``est-mllt`` estimates a new
        transform, ``gmm-transform-means`` applies it to the current model in
        place, and the new transform is composed with the existing ``lda.mat``
        (or becomes ``lda.mat`` when none exists yet).

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`
            Multiprocessing helper function for each job
        :meth:`.LdaTrainer.calc_lda_mllt_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`est-mllt`
            Relevant Kaldi binary
        :kaldi_src:`gmm-transform-means`
            Relevant Kaldi binary
        :kaldi_src:`compose-transforms`
            Relevant Kaldi binary
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script
        """
        logger.info("Re-calculating LDA...")
        arguments = self.calc_lda_mllt_arguments()
        with tqdm.tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            for _ in self._iter_lda_job_results(CalcLdaMlltFunction, arguments):
                pbar.update(1)

        log_path = os.path.join(
            self.working_log_directory, f"transform_means.{self.iteration}.log"
        )
        previous_mat_path = os.path.join(self.working_directory, "lda.mat")
        new_mat_path = os.path.join(self.working_directory, "lda_new.mat")
        composed_path = os.path.join(self.working_directory, "lda_composed.mat")
        macc_list = [path for x in arguments for path in x.macc_paths.values()]
        with mfa_open(log_path, "a") as log_file:
            self._call_kaldi_binary(
                [thirdparty_binary("est-mllt"), new_mat_path] + macc_list,
                log_file,
                log_path,
            )
            self._call_kaldi_binary(
                [
                    thirdparty_binary("gmm-transform-means"),
                    new_mat_path,
                    self.model_path,
                    self.model_path,
                ],
                log_file,
                log_path,
            )

            if os.path.exists(previous_mat_path):
                self._call_kaldi_binary(
                    [
                        thirdparty_binary("compose-transforms"),
                        new_mat_path,
                        previous_mat_path,
                        composed_path,
                    ],
                    log_file,
                    log_path,
                )
                # os.replace overwrites the destination on every platform and, unlike
                # remove-then-rename, leaves the existing lda.mat untouched when the
                # composed transform was never written.
                os.replace(composed_path, previous_mat_path)
            else:
                os.rename(new_mat_path, previous_mat_path)

    def train_iteration(self) -> None:
        """
        Run a single LDA training iteration

        When the next model file already exists, the realignment, MLLT and
        statistics steps are skipped and only the gaussian count and iteration
        counter are advanced.
        """
        if not os.path.exists(self.next_model_path):
            if self.iteration in self.realignment_iterations:
                self.align_iteration()
            if self.iteration in self.mllt_iterations:
                self.calc_lda_mllt()

            self.acc_stats()
            parse_logs(self.working_log_directory)
        if self.iteration <= self.final_gaussian_iteration:
            self.increment_gaussians()
        self.iteration += 1