"""Multiprocessing functions for training ivector extractors"""
from __future__ import annotations
import os
import re
import subprocess
import typing
from pathlib import Path
from sqlalchemy.orm import Session, joinedload
from montreal_forced_aligner.abc import MetaDict
from montreal_forced_aligner.data import MfaArguments
from montreal_forced_aligner.db import Job
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import KaldiFunction, thirdparty_binary
__all__ = [
"GmmGselectFunction",
"GmmGselectArguments",
"GaussToPostFunction",
"GaussToPostArguments",
"AccGlobalStatsFunction",
"AccGlobalStatsArguments",
"AccIvectorStatsFunction",
"AccIvectorStatsArguments",
]
[docs]
class GmmGselectArguments(MfaArguments):
"""Arguments for :func:`~montreal_forced_aligner.ivector.trainer.GmmGselectFunction`"""
feature_options: MetaDict
ivector_options: MetaDict
dubm_model: Path
gselect_path: Path
[docs]
class AccGlobalStatsArguments(MfaArguments):
"""Arguments for :func:`~montreal_forced_aligner.ivector.trainer.AccGlobalStatsFunction`"""
feature_options: MetaDict
ivector_options: MetaDict
gselect_path: Path
acc_path: Path
dubm_model: Path
[docs]
class GaussToPostArguments(MfaArguments):
"""Arguments for :func:`~montreal_forced_aligner.ivector.trainer.GaussToPostFunction`"""
feature_options: MetaDict
ivector_options: MetaDict
post_path: Path
dubm_model: Path
[docs]
class AccIvectorStatsArguments(MfaArguments):
"""Arguments for :func:`~montreal_forced_aligner.ivector.trainer.AccIvectorStatsFunction`"""
feature_options: MetaDict
ivector_options: MetaDict
ie_path: Path
post_path: Path
acc_path: Path
[docs]
class GmmGselectFunction(KaldiFunction):
"""
Multiprocessing function for selecting GMM indices.
See Also
--------
:meth:`.DubmTrainer.gmm_gselect`
Main function that calls this function in parallel
:meth:`.DubmTrainer.gmm_gselect_arguments`
Job method for generating arguments for this function
:kaldi_src:`subsample-feats`
Relevant Kaldi binary
:kaldi_src:`gmm-gselect`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.ivector.trainer.GmmGselectArguments`
Arguments for the function
"""
progress_pattern = re.compile(r"^LOG.*For (?P<done_count>\d+)'th.*")
def __init__(self, args: GmmGselectArguments):
super().__init__(args)
self.feature_options = args.feature_options
self.ivector_options = args.ivector_options
self.dubm_model = args.dubm_model
self.gselect_path = args.gselect_path
def _run(self) -> typing.Generator[None]:
"""Run the function"""
if os.path.exists(self.gselect_path):
return
with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True))
.filter(Job.id == self.job_name)
.first()
)
current_done_count = 0
feature_string = job.construct_online_feature_proc_string()
gselect_proc = subprocess.Popen(
[
thirdparty_binary("gmm-gselect"),
f"--n={self.ivector_options['num_gselect']}",
self.dubm_model,
feature_string,
f"ark:{self.gselect_path}",
],
stderr=subprocess.PIPE,
env=os.environ,
encoding="utf8",
)
for line in gselect_proc.stderr:
log_file.write(line)
m = self.progress_pattern.match(line)
if m:
new_done_count = int(m.group("done_count"))
yield new_done_count - current_done_count
current_done_count = new_done_count
self.check_call(gselect_proc)
[docs]
class GaussToPostFunction(KaldiFunction):
"""
Multiprocessing function to get posteriors during UBM training.
See Also
--------
:meth:`.IvectorTrainer.gauss_to_post`
Main function that calls this function in parallel
:meth:`.IvectorTrainer.gauss_to_post_arguments`
Job method for generating arguments for this function
:kaldi_src:`subsample-feats`
Relevant Kaldi binary
:kaldi_src:`gmm-global-get-post`
Relevant Kaldi binary
:kaldi_src:`scale-post`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.ivector.trainer.GaussToPostArguments`
Arguments for the function
"""
progress_pattern = re.compile(
r"^VLOG.*Processed utterance (?P<utterance>.*), average likelihood.*$"
)
def __init__(self, args: GaussToPostArguments):
super().__init__(args)
self.feature_options = args.feature_options
self.ivector_options = args.ivector_options
self.dubm_model = args.dubm_model
self.post_path = args.post_path
def _run(self) -> typing.Generator[None]:
"""Run the function"""
if os.path.exists(self.post_path):
return
modified_posterior_scale = (
self.ivector_options["posterior_scale"] * self.ivector_options["subsample"]
)
with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True))
.filter(Job.id == self.job_name)
.first()
)
feature_string = job.construct_online_feature_proc_string()
gmm_global_get_post_proc = subprocess.Popen(
[
thirdparty_binary("gmm-global-get-post"),
"--verbose=2",
f"--n={self.ivector_options['num_gselect']}",
f"--min-post={self.ivector_options['min_post']}",
self.dubm_model,
feature_string,
"ark:-",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=os.environ,
)
scale_post_proc = subprocess.Popen(
[
thirdparty_binary("scale-post"),
"ark,s,cs:-",
str(modified_posterior_scale),
f"ark:{self.post_path}",
],
stdin=gmm_global_get_post_proc.stdout,
stderr=log_file,
env=os.environ,
)
for line in gmm_global_get_post_proc.stderr:
line = line.decode("utf8")
log_file.write(line)
log_file.flush()
m = self.progress_pattern.match(line)
if m:
utterance = int(m.group("utterance").split("-")[-1])
yield utterance
self.check_call(scale_post_proc)
[docs]
class AccGlobalStatsFunction(KaldiFunction):
"""
Multiprocessing function for accumulating global model stats.
See Also
--------
:meth:`.DubmTrainer.acc_global_stats`
Main function that calls this function in parallel
:meth:`.DubmTrainer.acc_global_stats_arguments`
Job method for generating arguments for this function
:kaldi_src:`subsample-feats`
Relevant Kaldi binary
:kaldi_src:`gmm-global-acc-stats`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.ivector.trainer.AccGlobalStatsArguments`
Arguments for the function
"""
progress_pattern = re.compile(r"^VLOG.*File '(?P<file>.*)': Average likelihood =.*$")
def __init__(self, args: AccGlobalStatsArguments):
super().__init__(args)
self.feature_options = args.feature_options
self.ivector_options = args.ivector_options
self.dubm_model = args.dubm_model
self.gselect_path = args.gselect_path
self.acc_path = args.acc_path
def _run(self) -> typing.Generator[None]:
"""Run the function"""
with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True))
.filter(Job.id == self.job_name)
.first()
)
feature_string = job.construct_online_feature_proc_string()
command = [
thirdparty_binary("gmm-global-acc-stats"),
"--verbose=2",
f"--gselect=ark,s,cs:{self.gselect_path}",
self.dubm_model,
feature_string,
self.acc_path,
]
gmm_global_acc_proc = subprocess.Popen(
command,
stderr=subprocess.PIPE,
env=os.environ,
encoding="utf8",
)
for line in gmm_global_acc_proc.stderr:
log_file.write(line)
log_file.flush()
m = self.progress_pattern.match(line)
if m:
utt_id = int(m.group("file").split("-")[-1])
yield utt_id
self.check_call(gmm_global_acc_proc)
[docs]
class AccIvectorStatsFunction(KaldiFunction):
"""
Multiprocessing function that accumulates stats for ivector training.
See Also
--------
:meth:`.IvectorTrainer.acc_ivector_stats`
Main function that calls this function in parallel
:meth:`.IvectorTrainer.acc_ivector_stats_arguments`
Job method for generating arguments for this function
:kaldi_src:`subsample-feats`
Relevant Kaldi binary
:kaldi_src:`ivector-extractor-acc-stats`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.ivector.trainer.AccIvectorStatsArguments`
Arguments for the function
"""
progress_pattern = re.compile(r"VLOG.* Per frame, auxf is: weight.*")
def __init__(self, args: AccIvectorStatsArguments):
super().__init__(args)
self.feature_options = args.feature_options
self.ivector_options = args.ivector_options
self.ie_path = args.ie_path
self.post_path = args.post_path
self.acc_path = args.acc_path
def _run(self) -> typing.Generator[None]:
"""Run the function"""
with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True))
.filter(Job.id == self.job_name)
.first()
)
feature_string = job.construct_online_feature_proc_string()
acc_stats_proc = subprocess.Popen(
[
thirdparty_binary("ivector-extractor-acc-stats"),
"--verbose=4",
self.ie_path,
feature_string,
f"ark,s,cs:{self.post_path}",
self.acc_path,
],
stderr=subprocess.PIPE,
env=os.environ,
encoding="utf8",
)
for line in acc_stats_proc.stderr:
m = self.progress_pattern.match(line)
if m:
yield 1
continue
elif "VLOG" in line:
continue
log_file.write(line)
log_file.flush()
self.check_call(acc_stats_proc)