"""Class definitions for LDA trainer"""
from __future__ import annotations
import logging
import multiprocessing as mp
import os
import re
import shutil
import subprocess
import typing
from pathlib import Path
from queue import Empty
from typing import TYPE_CHECKING, Dict, List
from tqdm.rich import tqdm
from montreal_forced_aligner.abc import KaldiFunction
from montreal_forced_aligner.acoustic_modeling.triphone import TriphoneTrainer
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import MfaArguments
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import (
KaldiProcessWorker,
Stopped,
parse_logs,
thirdparty_binary,
)
if TYPE_CHECKING:
from montreal_forced_aligner.abc import MetaDict
__all__ = [
"LdaTrainer",
"CalcLdaMlltFunction",
"CalcLdaMlltArguments",
"LdaAccStatsFunction",
"LdaAccStatsArguments",
]
logger = logging.getLogger("mfa")
[docs]
class LdaAccStatsArguments(MfaArguments):
"""Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`"""
dictionaries: List[str]
feature_strings: Dict[str, str]
ali_paths: Dict[str, Path]
model_path: Path
lda_options: MetaDict
acc_paths: Dict[str, Path]
[docs]
class CalcLdaMlltArguments(MfaArguments):
"""Arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`"""
dictionaries: List[str]
feature_strings: Dict[str, str]
ali_paths: Dict[str, Path]
model_path: Path
lda_options: MetaDict
macc_paths: Dict[str, Path]
[docs]
class LdaAccStatsFunction(KaldiFunction):
"""
Multiprocessing function to accumulate LDA stats
See Also
--------
:meth:`.LdaTrainer.lda_acc_stats`
Main function that calls this function in parallel
:meth:`.LdaTrainer.lda_acc_stats_arguments`
Job method for generating arguments for this function
:kaldi_src:`ali-to-post`
Relevant Kaldi binary
:kaldi_src:`weight-silence-post`
Relevant Kaldi binary
:kaldi_src:`acc-lda`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsArguments`
Arguments for the function
"""
progress_pattern = re.compile(r"^LOG.*Done (?P<done>\d+) files, failed for (?P<failed>\d+)$")
def __init__(self, args: LdaAccStatsArguments):
super().__init__(args)
self.dictionaries = args.dictionaries
self.feature_strings = args.feature_strings
self.ali_paths = args.ali_paths
self.model_path = args.model_path
self.acc_paths = args.acc_paths
self.lda_options = args.lda_options
def _run(self) -> typing.Generator[typing.Tuple[int, int]]:
"""Run the function"""
with mfa_open(self.log_path, "w") as log_file:
for dict_id in self.dictionaries:
ali_path = self.ali_paths[dict_id]
feature_string = self.feature_strings[dict_id]
acc_path = self.acc_paths[dict_id]
ali_to_post_proc = subprocess.Popen(
[thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"],
stderr=log_file,
stdout=subprocess.PIPE,
env=os.environ,
)
weight_silence_post_proc = subprocess.Popen(
[
thirdparty_binary("weight-silence-post"),
"0.0",
self.lda_options["silence_csl"],
self.model_path,
"ark:-",
"ark:-",
],
stdin=ali_to_post_proc.stdout,
stderr=log_file,
stdout=subprocess.PIPE,
env=os.environ,
)
acc_lda_post_proc = subprocess.Popen(
[
thirdparty_binary("acc-lda"),
f"--rand-prune={self.lda_options['random_prune']}",
self.model_path,
feature_string,
"ark,s,cs:-",
acc_path,
],
stdin=weight_silence_post_proc.stdout,
stderr=subprocess.PIPE,
encoding="utf8",
env=os.environ,
)
for line in acc_lda_post_proc.stderr:
log_file.write(line)
m = self.progress_pattern.match(line.strip())
if m:
yield int(m.group("done")), int(m.group("failed"))
self.check_call(acc_lda_post_proc)
[docs]
class CalcLdaMlltFunction(KaldiFunction):
"""
Multiprocessing function for estimating LDA with MLLT.
See Also
--------
:meth:`.LdaTrainer.calc_lda_mllt`
Main function that calls this function in parallel
:meth:`.LdaTrainer.calc_lda_mllt_arguments`
Job method for generating arguments for this function
:kaldi_src:`ali-to-post`
Relevant Kaldi binary
:kaldi_src:`weight-silence-post`
Relevant Kaldi binary
:kaldi_src:`gmm-acc-mllt`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltArguments`
Arguments for the function
"""
progress_pattern = re.compile(r"^LOG.*Average like for this file.*$")
def __init__(self, args: CalcLdaMlltArguments):
super().__init__(args)
self.dictionaries = args.dictionaries
self.feature_strings = args.feature_strings
self.ali_paths = args.ali_paths
self.model_path = args.model_path
self.macc_paths = args.macc_paths
self.lda_options = args.lda_options
def _run(self) -> typing.Generator[int]:
"""Run the function"""
# Estimating MLLT
with mfa_open(self.log_path, "w") as log_file:
for dict_id in self.dictionaries:
ali_path = self.ali_paths[dict_id]
feature_string = self.feature_strings[dict_id]
macc_path = self.macc_paths[dict_id]
post_proc = subprocess.Popen(
[thirdparty_binary("ali-to-post"), f"ark:{ali_path}", "ark:-"],
stdout=subprocess.PIPE,
stderr=log_file,
env=os.environ,
)
weight_proc = subprocess.Popen(
[
thirdparty_binary("weight-silence-post"),
"0.0",
self.lda_options["silence_csl"],
self.model_path,
"ark:-",
"ark:-",
],
stdin=post_proc.stdout,
stdout=subprocess.PIPE,
stderr=log_file,
env=os.environ,
)
acc_proc = subprocess.Popen(
[
thirdparty_binary("gmm-acc-mllt"),
f"--rand-prune={self.lda_options['random_prune']}",
self.model_path,
feature_string,
"ark,s,cs:-",
macc_path,
],
stdin=weight_proc.stdout,
stderr=subprocess.PIPE,
encoding="utf8",
env=os.environ,
)
for line in acc_proc.stderr:
log_file.write(line)
m = self.progress_pattern.match(line.strip())
if m:
yield 1
self.check_call(acc_proc)
[docs]
class LdaTrainer(TriphoneTrainer):
"""
Triphone trainer
Parameters
----------
subset : int
Number of utterances to use, defaults to 10000
num_leaves : int
Number of states in the decision tree, defaults to 2500
max_gaussians : int
Number of gaussians in the decision tree, defaults to 15000
lda_dimension : int
Dimensionality of the LDA matrix
uses_splices : bool
Flag to use spliced and LDA calculation
splice_left_context : int or None
Number of frames to splice on the left for calculating LDA
splice_right_context : int or None
Number of frames to splice on the right for calculating LDA
random_prune : float
This is approximately the ratio by which we will speed up the
LDA and MLLT calculations via randomized pruning
See Also
--------
:class:`~montreal_forced_aligner.acoustic_modeling.triphone.TriphoneTrainer`
For acoustic model training parsing parameters
Attributes
----------
mllt_iterations : list
List of iterations to perform MLLT estimation
"""
def __init__(
self,
subset: int = 10000,
num_leaves: int = 2500,
max_gaussians=15000,
lda_dimension: int = 40,
uses_splices: bool = True,
splice_left_context: int = 3,
splice_right_context: int = 3,
random_prune=4.0,
boost_silence: float = 1.0,
power: float = 0.25,
**kwargs,
):
super().__init__(
boost_silence=boost_silence,
power=power,
subset=subset,
num_leaves=num_leaves,
max_gaussians=max_gaussians,
**kwargs,
)
self.lda_dimension = lda_dimension
self.random_prune = random_prune
self.uses_splices = uses_splices
self.splice_left_context = splice_left_context
self.splice_right_context = splice_right_context
[docs]
def lda_acc_stats_arguments(self) -> List[LdaAccStatsArguments]:
"""
Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`
Returns
-------
list[:class:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsArguments`]
Arguments for processing
"""
arguments = []
for j in self.jobs:
feat_strings = {}
for d_id in j.dictionary_ids:
feat_strings[d_id] = j.construct_feature_proc_string(
self.working_directory,
d_id,
self.feature_options["uses_splices"],
self.feature_options["splice_left_context"],
self.feature_options["splice_right_context"],
self.feature_options["uses_speaker_adaptation"],
)
arguments.append(
LdaAccStatsArguments(
j.id,
getattr(self, "db_string", ""),
self.working_log_directory.joinpath(f"lda_acc_stats.{j.id}.log"),
j.dictionary_ids,
feat_strings,
j.construct_path_dictionary(
self.previous_aligner.working_directory, "ali", "ark"
),
self.previous_aligner.alignment_model_path,
self.lda_options,
j.construct_path_dictionary(self.working_directory, "lda", "acc"),
)
)
return arguments
[docs]
def calc_lda_mllt_arguments(self) -> List[CalcLdaMlltArguments]:
"""
Generate Job arguments for :func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`
Returns
-------
list[:class:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltArguments`]
Arguments for processing
"""
arguments = []
for j in self.jobs:
feat_strings = {}
for d_id in j.dictionary_ids:
feat_strings[d_id] = j.construct_feature_proc_string(
self.working_directory,
d_id,
self.feature_options["uses_splices"],
self.feature_options["splice_left_context"],
self.feature_options["splice_right_context"],
self.feature_options["uses_speaker_adaptation"],
)
arguments.append(
CalcLdaMlltArguments(
j.id,
getattr(self, "db_string", ""),
os.path.join(
self.working_log_directory, f"lda_mllt.{self.iteration}.{j.id}.log"
),
j.dictionary_ids,
feat_strings,
j.construct_path_dictionary(self.working_directory, "ali", "ark"),
self.model_path,
self.lda_options,
j.construct_path_dictionary(self.working_directory, "lda", "macc"),
)
)
return arguments
@property
def train_type(self) -> str:
"""Training identifier"""
return "lda"
@property
def lda_options(self) -> MetaDict:
"""Options for computing LDA"""
return {
"lda_dimension": self.lda_dimension,
"random_prune": self.random_prune,
"silence_csl": self.silence_csl,
"splice_left_context": self.splice_left_context,
"splice_right_context": self.splice_right_context,
}
[docs]
def compute_calculated_properties(self) -> None:
"""Generate realignment iterations, MLLT estimation iterations, and initial gaussians based on configuration"""
super().compute_calculated_properties()
self.mllt_iterations = [2, 4, 6, 12]
[docs]
def lda_acc_stats(self) -> None:
"""
Multiprocessing function that accumulates LDA statistics.
See Also
--------
:func:`~montreal_forced_aligner.acoustic_modeling.lda.LdaAccStatsFunction`
Multiprocessing helper function for each job
:meth:`.LdaTrainer.lda_acc_stats_arguments`
Job method for generating arguments for the helper function
:kaldi_src:`est-lda`
Relevant Kaldi binary
:kaldi_steps:`train_lda_mllt`
Reference Kaldi script
"""
worker_lda_path = os.path.join(self.worker.working_directory, "lda.mat")
lda_path = self.working_directory.joinpath("lda.mat")
if os.path.exists(worker_lda_path):
os.remove(worker_lda_path)
arguments = self.lda_acc_stats_arguments()
with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
if GLOBAL_CONFIG.use_mp:
error_dict = {}
return_queue = mp.Queue()
stopped = Stopped()
procs = []
for i, args in enumerate(arguments):
function = LdaAccStatsFunction(args)
p = KaldiProcessWorker(i, return_queue, function, stopped)
procs.append(p)
p.start()
while True:
try:
result = return_queue.get(timeout=1)
if isinstance(result, Exception):
error_dict[getattr(result, "job_name", 0)] = result
continue
if stopped.stop_check():
continue
except Empty:
for proc in procs:
if not proc.finished.stop_check():
break
else:
break
continue
done, errors = result
pbar.update(done + errors)
for p in procs:
p.join()
if error_dict:
for v in error_dict.values():
raise v
else:
for args in arguments:
function = LdaAccStatsFunction(args)
for done, errors in function.run():
pbar.update(done + errors)
log_path = self.working_log_directory.joinpath("lda_est.log")
acc_list = []
for x in arguments:
acc_list.extend(x.acc_paths.values())
with mfa_open(log_path, "w") as log_file:
est_lda_proc = subprocess.Popen(
[
thirdparty_binary("est-lda"),
f"--dim={self.lda_dimension}",
lda_path,
]
+ acc_list,
stderr=log_file,
env=os.environ,
)
est_lda_proc.communicate()
shutil.copyfile(
lda_path,
worker_lda_path,
)
def _trainer_initialization(self) -> None:
"""Initialize LDA training"""
self.uses_splices = True
self.worker.uses_splices = True
if self.initialized:
return
self.lda_acc_stats()
self.tree_stats()
self._setup_tree(initial_mix_up=False)
self.compile_train_graphs()
self.convert_alignments()
os.rename(self.model_path, self.next_model_path)
[docs]
def calc_lda_mllt(self) -> None:
"""
Multiprocessing function that calculates LDA+MLLT transformations.
See Also
--------
:func:`~montreal_forced_aligner.acoustic_modeling.lda.CalcLdaMlltFunction`
Multiprocessing helper function for each job
:meth:`.LdaTrainer.calc_lda_mllt_arguments`
Job method for generating arguments for the helper function
:kaldi_src:`est-mllt`
Relevant Kaldi binary
:kaldi_src:`gmm-transform-means`
Relevant Kaldi binary
:kaldi_src:`compose-transforms`
Relevant Kaldi binary
:kaldi_steps:`train_lda_mllt`
Reference Kaldi script
"""
logger.info("Re-calculating LDA...")
arguments = self.calc_lda_mllt_arguments()
with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
if GLOBAL_CONFIG.use_mp:
error_dict = {}
return_queue = mp.Queue()
stopped = Stopped()
procs = []
for i, args in enumerate(arguments):
function = CalcLdaMlltFunction(args)
p = KaldiProcessWorker(i, return_queue, function, stopped)
procs.append(p)
p.start()
while True:
try:
result = return_queue.get(timeout=1)
if isinstance(result, Exception):
error_dict[getattr(result, "job_name", 0)] = result
continue
if stopped.stop_check():
continue
except Empty:
for proc in procs:
if not proc.finished.stop_check():
break
else:
break
continue
pbar.update(1)
for p in procs:
p.join()
if error_dict:
for v in error_dict.values():
raise v
else:
for args in arguments:
function = CalcLdaMlltFunction(args)
for _ in function.run():
pbar.update(1)
log_path = os.path.join(
self.working_log_directory, f"transform_means.{self.iteration}.log"
)
previous_mat_path = self.working_directory.joinpath("lda.mat")
new_mat_path = self.working_directory.joinpath("lda_new.mat")
composed_path = self.working_directory.joinpath("lda_composed.mat")
with mfa_open(log_path, "a") as log_file:
macc_list = []
for x in arguments:
macc_list.extend(x.macc_paths.values())
subprocess.call(
[thirdparty_binary("est-mllt"), new_mat_path] + macc_list,
stderr=log_file,
env=os.environ,
)
subprocess.call(
[
thirdparty_binary("gmm-transform-means"),
new_mat_path,
self.model_path,
self.model_path,
],
stderr=log_file,
env=os.environ,
)
if os.path.exists(previous_mat_path):
subprocess.call(
[
thirdparty_binary("compose-transforms"),
new_mat_path,
previous_mat_path,
composed_path,
],
stderr=log_file,
env=os.environ,
)
os.remove(previous_mat_path)
os.rename(composed_path, previous_mat_path)
else:
os.rename(new_mat_path, previous_mat_path)
[docs]
def train_iteration(self) -> None:
"""
Run a single LDA training iteration
"""
if os.path.exists(self.next_model_path):
if self.iteration <= self.final_gaussian_iteration:
self.increment_gaussians()
self.iteration += 1
return
if self.iteration in self.realignment_iterations:
self.align_iteration()
if self.iteration in self.mllt_iterations:
self.calc_lda_mllt()
self.acc_stats()
parse_logs(self.working_log_directory)
if self.iteration <= self.final_gaussian_iteration:
self.increment_gaussians()
self.iteration += 1