Source code for montreal_forced_aligner.acoustic_modeling.triphone

"""Class definitions for TriphoneTrainer"""
from __future__ import annotations

import logging
import multiprocessing as mp
import os
import re
import subprocess
import typing
from pathlib import Path
from queue import Empty
from typing import TYPE_CHECKING, Dict, List

from tqdm.rich import tqdm

from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import MfaArguments
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import (
    KaldiFunction,
    KaldiProcessWorker,
    Stopped,
    parse_logs,
    run_mp,
    run_non_mp,
    thirdparty_binary,
)

if TYPE_CHECKING:
    from ..abc import MetaDict


# Public API of this module; ``tree_stats_func`` is the multiprocessing worker
# documented alongside ``ConvertAlignmentsFunction`` and is exported with it.
__all__ = [
    "TriphoneTrainer",
    "TreeStatsArguments",
    "tree_stats_func",
    "ConvertAlignmentsFunction",
    "ConvertAlignmentsArguments",
]

logger = logging.getLogger("mfa")


class TreeStatsArguments(MfaArguments):
    """
    Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.triphone.tree_stats_func`

    The mapping fields below are keyed by the dictionary IDs listed in ``dictionaries``.
    """

    # Dictionary IDs processed by this job; each is used as a key into the mappings below
    dictionaries: List[str]
    # Passed to acc-tree-stats as --ci-phones (a ``*_csl`` value from the worker,
    # presumably colon-separated phone IDs -- confirm)
    ci_phones: str
    # Alignment model handed to acc-tree-stats
    model_path: Path
    # Per-dictionary feature argument handed to acc-tree-stats
    feature_strings: Dict[str, str]
    # Per-dictionary alignment archives, read as ``ark:<path>``
    ali_paths: Dict[str, Path]
    # Per-dictionary output file for the accumulated tree statistics
    treeacc_paths: Dict[str, Path]
class ConvertAlignmentsArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsFunction`

    The mapping fields below are keyed by the dictionary IDs listed in ``dictionaries``.
    """

    # Dictionary IDs processed by this job
    dictionaries: List[str]
    # Model the alignments are converted to
    model_path: Path
    # Tree that accompanies ``model_path``
    tree_path: Path
    # Model the existing alignments were produced with
    align_model_path: Path
    # Per-dictionary input alignment archives (from the previous trainer)
    ali_paths: Dict[str, Path]
    # Per-dictionary output alignment archives
    new_ali_paths: Dict[str, Path]
class ConvertAlignmentsFunction(KaldiFunction):
    """
    Multiprocessing function for converting alignments from a previous trainer

    See Also
    --------
    :meth:`.TriphoneTrainer.convert_alignments`
        Main function that calls this function in parallel
    :meth:`.TriphoneTrainer.convert_alignments_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`convert-ali`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsArguments`
        Arguments for the function
    """

    # Matches the summary line convert-ali writes to stderr once it finishes
    progress_pattern = re.compile(
        r"^LOG.*Succeeded converting alignments for (?P<utterances>\d+) files, failed for (?P<failed>\d+)$"
    )

    def __init__(self, args: ConvertAlignmentsArguments):
        super().__init__(args)
        self.dictionaries = args.dictionaries
        self.model_path = args.model_path
        self.tree_path = args.tree_path
        self.align_model_path = args.align_model_path
        self.ali_paths = args.ali_paths
        self.new_ali_paths = args.new_ali_paths

    def _run(self) -> typing.Generator[typing.Tuple[int, int], None, None]:
        """
        Run ``convert-ali`` once per dictionary in the job.

        All stderr output is copied to the job's log file.

        Yields
        ------
        tuple[int, int]
            Converted and failed counts parsed from each summary line that
            matches :attr:`progress_pattern`
        """
        with mfa_open(self.log_path, "w") as log_file:
            for dict_id in self.dictionaries:
                ali_path = self.ali_paths[dict_id]
                new_ali_path = self.new_ali_paths[dict_id]
                # Argument order: old model, new model, new tree, input alignments, output alignments
                convert_proc = subprocess.Popen(
                    [
                        thirdparty_binary("convert-ali"),
                        self.align_model_path,
                        self.model_path,
                        self.tree_path,
                        f"ark:{ali_path}",
                        f"ark:{new_ali_path}",
                    ],
                    stderr=subprocess.PIPE,
                    encoding="utf8",
                    env=os.environ,
                )
                for line in convert_proc.stderr:
                    log_file.write(line)
                    m = self.progress_pattern.match(line.strip())
                    if m:
                        yield int(m.group("utterances")), int(m.group("failed"))
                self.check_call(convert_proc)
def tree_stats_func(
    arguments: TreeStatsArguments,
) -> None:
    """
    Accumulate decision-tree statistics for every dictionary in a job.

    For each dictionary ID in ``arguments.dictionaries``, Kaldi's ``acc-tree-stats``
    runs once. It reads that dictionary's features and alignments and writes its
    tree accumulator file. The stderr of every invocation goes to the job's single
    log file.

    See Also
    --------
    :meth:`.TriphoneTrainer.tree_stats`
        Main function that calls this function in parallel
    :meth:`.TriphoneTrainer.tree_stats_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`acc-tree-stats`
        Relevant Kaldi binary

    Parameters
    ----------
    arguments: TreeStatsArguments
        Arguments for the function
    """
    with mfa_open(arguments.log_path, "w") as log_handle:
        for dictionary_id in arguments.dictionaries:
            features = arguments.feature_strings[dictionary_id]
            alignments = arguments.ali_paths[dictionary_id]
            accumulator = arguments.treeacc_paths[dictionary_id]
            command = [
                thirdparty_binary("acc-tree-stats"),
                f"--ci-phones={arguments.ci_phones}",
                arguments.model_path,
                features,
                f"ark:{alignments}",
                accumulator,
            ]
            subprocess.call(command, stderr=log_handle)
class TriphoneTrainer(AcousticModelTrainingMixin):
    """
    Triphone trainer

    Parameters
    ----------
    subset : int
        Number of utterances to use, defaults to 5000
    num_iterations : int
        Number of training iterations to perform, defaults to 35
    num_leaves : int
        Number of states in the decision tree, passed to ``build-tree`` as
        ``--max-leaves`` and also used as the initial number of gaussians,
        defaults to 1000
    max_gaussians : int
        Maximum number of gaussians, defaults to 10000
    cluster_threshold : int
        Passed to ``build-tree`` as ``--cluster-thresh`` to control the final
        bottom-up clustering of leaves, defaults to -1
    boost_silence : float
        Silence boosting factor forwarded to the base trainer, defaults to 1.0
    power : float
        Exponent forwarded to the base trainer (presumably the gaussian
        allocation power used when mixing up -- confirm), defaults to 0.25

    See Also
    --------
    :class:`~montreal_forced_aligner.acoustic_modeling.base.AcousticModelTrainingMixin`
        For acoustic model training parsing parameters
    """

    def __init__(
        self,
        subset: int = 5000,
        num_iterations: int = 35,
        num_leaves: int = 1000,
        max_gaussians: int = 10000,
        cluster_threshold: int = -1,
        boost_silence: float = 1.0,
        power: float = 0.25,
        **kwargs,
    ):
        super().__init__(
            num_iterations=num_iterations,
            boost_silence=boost_silence,
            power=power,
            subset=subset,
            initial_gaussians=num_leaves,
            max_gaussians=max_gaussians,
            **kwargs,
        )
        self.num_leaves = num_leaves
        self.cluster_threshold = cluster_threshold

    def tree_stats_arguments(self) -> List[TreeStatsArguments]:
        """
        Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.triphone.tree_stats_func`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.acoustic_modeling.triphone.TreeStatsArguments`]
            Arguments for processing
        """
        # Stats are accumulated against the previous aligner's final model;
        # build a Path to match TreeStatsArguments.model_path's annotation.
        alignment_model_path = Path(self.previous_aligner.working_directory).joinpath(
            "final.mdl"
        )
        arguments = []
        for j in self.jobs:
            feat_strings = {}
            ali_paths = {}
            treeacc_paths = {}
            for d_id in j.dictionary_ids:
                feat_strings[d_id] = j.construct_feature_proc_string(
                    self.working_directory,
                    d_id,
                    self.feature_options["uses_splices"],
                    self.feature_options["splice_left_context"],
                    self.feature_options["splice_right_context"],
                    self.feature_options["uses_speaker_adaptation"],
                )
                ali_paths[d_id] = j.construct_path(
                    self.previous_aligner.working_directory, "ali", "ark", d_id
                )
                treeacc_paths[d_id] = j.construct_path(self.working_directory, "tree", "acc", d_id)
            arguments.append(
                TreeStatsArguments(
                    j.id,
                    getattr(self, "db_string", ""),
                    self.working_log_directory.joinpath(f"acc_tree.{j.id}.log"),
                    j.dictionary_ids,
                    self.worker.context_independent_csl,
                    alignment_model_path,
                    feat_strings,
                    ali_paths,
                    treeacc_paths,
                )
            )
        return arguments

    def convert_alignments_arguments(self) -> List[ConvertAlignmentsArguments]:
        """
        Generate Job arguments for :class:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsArguments`]
            Arguments for processing
        """
        return [
            ConvertAlignmentsArguments(
                j.id,
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"convert_alignments.{j.id}.log"),
                j.dictionary_ids,
                self.model_path,
                self.tree_path,
                self.previous_aligner.model_path,
                j.construct_path_dictionary(self.previous_aligner.working_directory, "ali", "ark"),
                j.construct_path_dictionary(self.working_directory, "ali", "ark"),
            )
            for j in self.jobs
        ]

    def convert_alignments(self) -> None:
        """
        Multiprocessing function that converts alignments from previous training

        See Also
        --------
        :class:`~montreal_forced_aligner.acoustic_modeling.triphone.ConvertAlignmentsFunction`
            Multiprocessing helper function for each job
        :meth:`.TriphoneTrainer.convert_alignments_arguments`
            Job method for generating arguments for the helper function
        :kaldi_steps:`train_deltas`
            Reference Kaldi script
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script
        """
        logger.info("Converting alignments...")
        arguments = self.convert_alignments_arguments()
        with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            if GLOBAL_CONFIG.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = ConvertAlignmentsFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            # Remember the failure; it is re-raised after all workers are joined
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            continue
                    except Empty:
                        # Queue idle: stop polling only once every worker reports finished
                        for proc in procs:
                            if not proc.finished.stop_check():
                                break
                        else:
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    for v in error_dict.values():
                        raise v
            else:
                for args in arguments:
                    function = ConvertAlignmentsFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)

    def acoustic_model_training_params(self) -> MetaDict:
        """Configuration parameters"""
        return {
            "num_iterations": self.num_iterations,
            "num_leaves": self.num_leaves,
            "max_gaussians": self.max_gaussians,
            "cluster_threshold": self.cluster_threshold,
        }

    def compute_calculated_properties(self) -> None:
        """Generate realignment iterations and initial gaussians based on configuration"""
        # Realign on every 10th iteration, excluding iteration 0
        self.realignment_iterations.extend(range(10, self.num_iterations, 10))
        self.initial_gaussians = self.num_leaves
        self.final_gaussian_iteration = self.num_iterations - 10

    @property
    def train_type(self) -> str:
        """Training identifier"""
        return "tri"

    @property
    def phone_type(self) -> str:
        """Phone type"""
        return "triphone"

    def _trainer_initialization(self) -> None:
        """Triphone training initialization"""
        if self.initialized:
            return
        self.tree_stats()
        self._setup_tree()

        self.compile_train_graphs()

        self.convert_alignments()
        os.rename(self.model_path, self.next_model_path)

    def tree_stats(self) -> None:
        """
        Multiprocessing function that computes stats for decision tree training.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.triphone.tree_stats_func`
            Multiprocessing helper function for each job
        :meth:`.TriphoneTrainer.tree_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`sum-tree-stats`
            Relevant Kaldi binary
        :kaldi_steps:`train_deltas`
            Reference Kaldi script
        :kaldi_steps:`train_lda_mllt`
            Reference Kaldi script
        :kaldi_steps:`train_sat`
            Reference Kaldi script
        """
        jobs = self.tree_stats_arguments()

        if GLOBAL_CONFIG.use_mp:
            run_mp(tree_stats_func, jobs, self.working_log_directory)
        else:
            run_non_mp(tree_stats_func, jobs, self.working_log_directory)

        # Merge the per-job/per-dictionary accumulators into a single treeacc file
        tree_accs = []
        for x in jobs:
            tree_accs.extend(x.treeacc_paths.values())
        log_path = self.working_log_directory.joinpath("sum_tree_acc.log")
        with mfa_open(log_path, "w") as log_file:
            subprocess.call(
                [
                    thirdparty_binary("sum-tree-stats"),
                    self.working_directory.joinpath("treeacc"),
                ]
                + tree_accs,
                stderr=log_file,
            )
        # Per-job accumulators are kept only in debug mode
        if not GLOBAL_CONFIG.debug:
            for f in tree_accs:
                os.remove(f)

    def _setup_tree(self, init_from_previous: bool = False, initial_mix_up: bool = True) -> None:
        """
        Set up the tree for the triphone model

        Runs ``cluster-phones``, ``compile-questions``, ``build-tree`` and
        ``gmm-init-model`` on the summed tree statistics, optionally followed by
        ``gmm-mixup``.

        Parameters
        ----------
        init_from_previous : bool
            If True, pass the previous aligner's ``tree`` and ``final.mdl`` to
            ``gmm-init-model``, and mix down as well as up to ``initial_gaussians``,
            defaults to False
        initial_mix_up : bool
            If True, run ``gmm-mixup`` to ``initial_gaussians`` after initialization,
            defaults to True

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
            If there were any errors in running Kaldi binaries
        """
        log_path = self.working_log_directory.joinpath("questions.log")
        tree_path = self.working_directory.joinpath("tree")
        treeacc_path = self.working_directory.joinpath("treeacc")
        sets_int_path = os.path.join(self.worker.phones_dir, "sets.int")
        roots_int_path = os.path.join(self.worker.phones_dir, "roots.int")
        extra_question_int_path = os.path.join(self.worker.phones_dir, "extra_questions.int")
        topo_path = self.worker.topo_path
        questions_path = self.working_directory.joinpath("questions.int")
        questions_qst_path = self.working_directory.joinpath("questions.qst")
        with mfa_open(log_path, "w") as log_file:
            subprocess.call(
                [
                    thirdparty_binary("cluster-phones"),
                    treeacc_path,
                    sets_int_path,
                    questions_path,
                ],
                stderr=log_file,
            )

        # Append the hand-specified extra questions to the automatically clustered ones
        with mfa_open(extra_question_int_path, "r") as inf, mfa_open(questions_path, "a") as outf:
            for line in inf:
                outf.write(line)

        log_path = self.working_log_directory.joinpath("compile_questions.log")
        with mfa_open(log_path, "w") as log_file:
            subprocess.call(
                [
                    thirdparty_binary("compile-questions"),
                    topo_path,
                    questions_path,
                    questions_qst_path,
                ],
                stderr=log_file,
            )

        log_path = self.working_log_directory.joinpath("build_tree.log")
        with mfa_open(log_path, "w") as log_file:
            subprocess.call(
                [
                    thirdparty_binary("build-tree"),
                    "--verbose=1",
                    f"--max-leaves={self.num_leaves}",
                    f"--cluster-thresh={self.cluster_threshold}",
                    treeacc_path,
                    roots_int_path,
                    questions_qst_path,
                    topo_path,
                    tree_path,
                ],
                stderr=log_file,
            )

        log_path = self.working_log_directory.joinpath("init_model.log")
        occs_path = self.working_directory.joinpath("0.occs")
        mdl_path = self.model_path
        if init_from_previous:
            command = [
                thirdparty_binary("gmm-init-model"),
                f"--write-occs={occs_path}",
                tree_path,
                treeacc_path,
                topo_path,
                mdl_path,
                os.path.join(self.previous_aligner.working_directory, "tree"),
                os.path.join(self.previous_aligner.working_directory, "final.mdl"),
            ]
        else:
            command = [
                thirdparty_binary("gmm-init-model"),
                f"--write-occs={occs_path}",
                tree_path,
                treeacc_path,
                topo_path,
                mdl_path,
            ]
        with mfa_open(log_path, "w") as log_file:
            subprocess.call(command, stderr=log_file)
        if initial_mix_up:
            if init_from_previous:
                command = [
                    thirdparty_binary("gmm-mixup"),
                    f"--mix-up={self.initial_gaussians}",
                    f"--mix-down={self.initial_gaussians}",
                    mdl_path,
                    occs_path,
                    mdl_path,
                ]
            else:
                command = [
                    thirdparty_binary("gmm-mixup"),
                    f"--mix-up={self.initial_gaussians}",
                    mdl_path,
                    occs_path,
                    mdl_path,
                ]
            log_path = self.working_log_directory.joinpath("mixup.log")
            with mfa_open(log_path, "w") as log_file:
                subprocess.call(command, stderr=log_file)
        os.remove(treeacc_path)
        os.rename(occs_path, self.next_occs_path)
        parse_logs(self.working_log_directory)