# Source code for montreal_forced_aligner.acoustic_modeling.monophone

"""Class definitions for Monophone trainer"""
from __future__ import annotations

import logging
import multiprocessing as mp
import os
import re
import subprocess
import typing
from pathlib import Path
from queue import Empty

from sqlalchemy.orm import Session, joinedload, subqueryload
from tqdm.rich import tqdm

from montreal_forced_aligner.abc import KaldiFunction
from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import MfaArguments
from montreal_forced_aligner.db import CorpusWorkflow, Job
from montreal_forced_aligner.exceptions import KaldiProcessingError
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import KaldiProcessWorker, Stopped, thirdparty_binary

if typing.TYPE_CHECKING:
    from montreal_forced_aligner.abc import MetaDict

# Names exported by ``from ... import *`` for this module
__all__ = ["MonophoneTrainer", "MonoAlignEqualFunction", "MonoAlignEqualArguments"]

# Logger retrieved under the name "mfa"; used below for progress messages
logger = logging.getLogger("mfa")


class MonoAlignEqualArguments(MfaArguments):
    """
    Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.monophone.MonoAlignEqualFunction`

    Extends :class:`~montreal_forced_aligner.data.MfaArguments` with the two
    values that the equal-alignment job copies onto itself in its constructor.
    """

    # Model path; handed to ``gmm-acc-stats-ali`` as its model argument
    model_path: Path
    # Feature options; the job reads the keys "uses_splices",
    # "splice_left_context", "splice_right_context" and "uses_speaker_adaptation"
    feature_options: MetaDict
class MonoAlignEqualFunction(KaldiFunction):
    """
    Multiprocessing function for initializing monophone alignments

    See Also
    --------
    :meth:`.MonophoneTrainer.mono_align_equal`
        Main function that calls this function in parallel
    :meth:`.MonophoneTrainer.mono_align_equal_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`align-equal-compiled`
        Relevant Kaldi binary
    :kaldi_src:`gmm-acc-stats-ali`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.acoustic_modeling.monophone.MonoAlignEqualArguments`
        Arguments for the function
    """

    # Matches the summary line written to stderr by gmm-acc-stats-ali
    progress_pattern = re.compile(
        r"^LOG.* Done (?P<utterances>\d+) files, (?P<errors>\d+) with errors.$"
    )

    def __init__(self, args: MonoAlignEqualArguments):
        super().__init__(args)
        self.model_path = args.model_path
        self.feature_options = args.feature_options

    def _run(self) -> typing.Generator[typing.Tuple[int, int], None, None]:
        """
        Run the function

        For every dictionary of the job, ``align-equal-compiled`` writes an
        alignment archive, then ``gmm-acc-stats-ali`` reads that archive and
        writes the accumulated statistics for iteration ``0``.

        Yields
        ------
        tuple[int, int]
            ``(utterances, errors)`` parsed from each ``gmm-acc-stats-ali``
            stderr line that matches :attr:`progress_pattern`
        """
        with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session:
            job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            workflow: CorpusWorkflow = (
                session.query(CorpusWorkflow)
                .filter(CorpusWorkflow.current == True)  # noqa
                .first()
            )
            for dict_id in job.dictionary_ids:
                feature_string = job.construct_feature_proc_string(
                    workflow.working_directory,
                    dict_id,
                    self.feature_options["uses_splices"],
                    self.feature_options["splice_left_context"],
                    self.feature_options["splice_right_context"],
                    self.feature_options["uses_speaker_adaptation"],
                )
                fst_ark_path = job.construct_path(
                    workflow.working_directory, "fsts", "ark", dict_id
                )
                ali_path = job.construct_path(workflow.working_directory, "ali", "ark", dict_id)
                acc_path = job.construct_path(workflow.working_directory, "0", "acc", dict_id)

                # Stage 1: write the alignment archive to disk and wait for it
                # to finish; nothing is piped between the two stages.
                align_proc = subprocess.Popen(
                    [
                        thirdparty_binary("align-equal-compiled"),
                        f"ark:{fst_ark_path}",
                        feature_string,
                        f"ark:{ali_path}",
                    ],
                    stderr=log_file,
                    env=os.environ,
                )
                align_proc.communicate()
                # Validate the first stage with the same inherited helper that
                # is applied to the second stage below, so that a failed run
                # is not silently accumulated from a partial archive.
                self.check_call(align_proc)

                # Stage 2: accumulate statistics from the archive on disk.
                # No ``stdin`` is passed: the first stage was started without
                # a stdout pipe, so there is nothing to connect.
                acc_proc = subprocess.Popen(
                    [
                        thirdparty_binary("gmm-acc-stats-ali"),
                        "--binary=true",
                        self.model_path,
                        feature_string,
                        f"ark:{ali_path}",
                        acc_path,
                    ],
                    stderr=subprocess.PIPE,
                    encoding="utf8",
                    env=os.environ,
                )
                # Mirror stderr into the job log while scanning for progress
                for line in acc_proc.stderr:
                    log_file.write(line)
                    m = self.progress_pattern.match(line.strip())
                    if m:
                        yield int(m.group("utterances")), int(m.group("errors"))
                self.check_call(acc_proc)
class MonophoneTrainer(AcousticModelTrainingMixin):
    """
    Configuration class for monophone training

    Attributes
    ----------
    subset : int
        Number of utterances to use, defaults to 2000
    initial_gaussians : int
        Number of gaussians to begin training, defaults to 135
    initial_beam : int
        Beam placed in the alignment options while ``iteration`` is 1, defaults to 6
    max_gaussians : int
        Total number of gaussians, defaults to 1000
    power : float
        Exponent for number of gaussians according to occurrence counts, defaults to 0.25

    See Also
    --------
    :class:`~montreal_forced_aligner.acoustic_modeling.base.AcousticModelTrainingMixin`
        For acoustic model training parsing parameters
    """

    def __init__(
        self,
        subset: int = 2000,
        initial_gaussians: int = 135,
        initial_beam: int = 6,
        max_gaussians: int = 1000,
        power: float = 0.25,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.subset = subset
        self.initial_gaussians = initial_gaussians
        self.initial_beam = initial_beam
        self.max_gaussians = max_gaussians
        self.power = power
        self.last_gaussian_increase_iteration = 0

    def mono_align_equal_arguments(self) -> typing.List[MonoAlignEqualArguments]:
        """
        Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.monophone.MonoAlignEqualFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.acoustic_modeling.monophone.MonoAlignEqualArguments`]
            Arguments for processing
        """
        return [
            MonoAlignEqualArguments(
                j.id,
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"mono_align_equal.{j.id}.log"),
                self.model_path,
                self.feature_options,
            )
            for j in self.jobs
        ]

    def compute_calculated_properties(self) -> None:
        """
        Generate realignment iterations and initial gaussians based on configuration

        Iteration 0 is always a realignment iteration. After that, every
        iteration is added through the first quarter of training, iterations
        at least two apart through the first half, and iterations at least
        three apart for the remainder.
        """
        self.final_gaussian_iteration = self.num_iterations - 10
        self.realignment_iterations = [0]
        first_quarter = int(self.num_iterations / 4)
        first_half = int(self.num_iterations * 2 / 4)
        for i in range(1, self.num_iterations):
            gap = i - self.realignment_iterations[-1]
            if i <= first_quarter:
                self.realignment_iterations.append(i)
            elif i <= first_half:
                if gap > 1:
                    self.realignment_iterations.append(i)
            elif gap > 2:
                self.realignment_iterations.append(i)

    @property
    def train_type(self) -> str:
        """Training identifier"""
        return "mono"

    @property
    def phone_type(self) -> str:
        """Phone type"""
        return "monophone"

    @property
    def align_options(self) -> MetaDict:
        """Alignment parameters, with the beam set to ``initial_beam`` on iteration 1"""
        options = super().align_options
        if self.iteration == 1:
            options["beam"] = self.initial_beam
        return options

    def mono_align_equal(self) -> None:
        """
        Multiprocessing function that creates equal alignments for base monophone training.

        Runs :class:`MonoAlignEqualFunction` for every job (in worker processes
        when ``GLOBAL_CONFIG.use_mp`` is set, otherwise serially), then sums the
        per-job accumulators and estimates the next model from them.

        See Also
        --------
        :func:`~montreal_forced_aligner.acoustic_modeling.monophone.MonoAlignEqualFunction`
            Multiprocessing helper function for each job
        :meth:`.MonophoneTrainer.mono_align_equal_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-sum-accs`
            Relevant Kaldi binary
        :kaldi_src:`gmm-est`
            Relevant Kaldi binary
        :kaldi_steps:`train_mono`
            Reference Kaldi script

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
            If any worker reported a Kaldi error (the logs of all failed jobs
            are combined), or if ``gmm-est`` exits with a non-zero status
        Exception
            Any other exception returned by a worker is re-raised as is
        """
        logger.info("Generating initial alignments...")
        arguments = self.mono_align_equal_arguments()
        with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            if GLOBAL_CONFIG.use_mp:
                error_dict = {}
                return_queue = mp.Queue()
                stopped = Stopped()
                procs = []
                for i, args in enumerate(arguments):
                    function = MonoAlignEqualFunction(args)
                    p = KaldiProcessWorker(i, return_queue, function, stopped)
                    procs.append(p)
                    p.start()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                        if isinstance(result, Exception):
                            error_dict[getattr(result, "job_name", 0)] = result
                            continue
                        if stopped.stop_check():
                            # Keep draining the queue without counting progress
                            continue
                    except Empty:
                        # Queue is idle: stop only once every worker is finished
                        if all(proc.finished.stop_check() for proc in procs):
                            break
                        continue
                    num_utterances, errors = result
                    pbar.update(num_utterances + errors)
                for p in procs:
                    p.join()
                if error_dict:
                    error_logs = []
                    for e in error_dict.values():
                        if isinstance(e, KaldiProcessingError):
                            error_logs.extend(e.error_logs)
                        else:
                            raise e
                    if error_logs:
                        # Report the logs gathered from every failed job, not
                        # only those of the last exception inspected above.
                        e = KaldiProcessingError(error_logs)
                        e.update_log_file()
                        raise e
            else:
                for args in arguments:
                    function = MonoAlignEqualFunction(args)
                    for num_utterances, errors in function.run():
                        pbar.update(num_utterances + errors)

        log_path = self.working_log_directory.joinpath("update.0.log")
        acc_files = [
            j.construct_path(self.working_directory, "0", "acc", dict_id)
            for j in self.jobs
            for dict_id in j.dictionary_ids
        ]
        with mfa_open(log_path, "w") as log_file:
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-"] + acc_files,
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            est_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-est"),
                    "--min-gaussian-occupancy=3",
                    f"--mix-up={self.current_gaussians}",
                    f"--power={self.power}",
                    self.model_path,
                    "-",
                    self.next_model_path,
                ],
                stderr=log_file,
                stdin=sum_proc.stdout,
                env=os.environ,
            )
            # Drop the parent's copy of the pipe so gmm-sum-accs is not left
            # blocked on a full pipe if gmm-est exits early.
            sum_proc.stdout.close()
            est_proc.communicate()
            # Reap the upstream process as well
            sum_proc.wait()
        if est_proc.returncode != 0:
            raise KaldiProcessingError([log_path])
        if not GLOBAL_CONFIG.debug:
            for f in acc_files:
                os.remove(f)

    def _trainer_initialization(self) -> None:
        """
        Monophone training initialization

        Builds the initial monophone model and tree from a 10-utterance feature
        subset, reads the number of gaussians back from ``gmm-info``, then
        compiles training graphs and generates the equal alignments.

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
            If ``gmm-info`` fails or its output contains no gaussian count
        """
        if self.initialized:
            return
        self.iteration = 0
        tree_path = self.working_directory.joinpath("tree")
        feat_dim = self.worker.get_feat_dim()
        first_job = self.jobs[0]
        feature_string = first_job.construct_feature_proc_string(
            self.working_directory,
            first_job.dictionary_ids[0],
            self.feature_options["uses_splices"],
            self.feature_options["splice_left_context"],
            self.feature_options["splice_right_context"],
            self.feature_options["uses_speaker_adaptation"],
        )
        shared_phones_path = os.path.join(self.worker.phones_dir, "sets.int")
        init_log_path = self.working_log_directory.joinpath("init.log")
        temp_feats_path = self.working_directory.joinpath("temp_feats")
        with mfa_open(init_log_path, "w") as log_file:
            # Exit codes of the first two commands are deliberately not
            # checked here; a failure surfaces through gmm-info below.
            subprocess.call(
                [
                    thirdparty_binary("subset-feats"),
                    "--n=10",
                    feature_string,
                    f"ark:{temp_feats_path}",
                ],
                stderr=log_file,
            )
            subprocess.call(
                [
                    thirdparty_binary("gmm-init-mono"),
                    f"--shared-phones={shared_phones_path}",
                    f"--train-feats=ark:{temp_feats_path}",
                    os.fspath(self.worker.topo_path),
                    str(feat_dim),
                    self.model_path,
                    tree_path,
                ],
                stderr=log_file,
            )
            proc = subprocess.Popen(
                [thirdparty_binary("gmm-info"), "--print-args=false", self.model_path],
                stderr=log_file,
                stdout=subprocess.PIPE,
                encoding="utf8",
            )
            stdout, _ = proc.communicate()
        if proc.returncode != 0:
            raise KaldiProcessingError([init_log_path])
        matches = re.search(r"gaussians (\d+)", stdout)
        if matches is None:
            # Unparseable output is reported like any other gmm-info failure
            raise KaldiProcessingError([init_log_path])
        num_gauss = int(matches.group(1))
        os.remove(temp_feats_path)
        self.initial_gaussians = num_gauss
        self.current_gaussians = num_gauss
        if os.path.exists(self.model_path):
            os.remove(
                init_log_path
            )  # Has some errors related to subsetting that trigger larger failures
        self.compile_train_graphs()
        self.mono_align_equal()