"""Class definitions for adapting acoustic models"""
from __future__ import annotations
import logging
import multiprocessing as mp
import os
import shutil
import subprocess
import time
from pathlib import Path
from queue import Empty
from typing import TYPE_CHECKING, List
from tqdm.rich import tqdm
from montreal_forced_aligner.abc import AdapterMixin
from montreal_forced_aligner.alignment.multiprocessing import AccStatsArguments, AccStatsFunction
from montreal_forced_aligner.alignment.pretrained import PretrainedAligner
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import WorkflowType
from montreal_forced_aligner.db import CorpusWorkflow
from montreal_forced_aligner.exceptions import KaldiProcessingError
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.models import AcousticModel
from montreal_forced_aligner.utils import (
KaldiProcessWorker,
Stopped,
log_kaldi_errors,
thirdparty_binary,
)
# MetaDict is only referenced in annotations, so it is imported for static
# type checkers only and never at runtime.
if TYPE_CHECKING:
    from montreal_forced_aligner.models import MetaDict

# Public API of this module.
__all__ = ["AdaptingAligner"]

# Messages from this module go to the logger named "mfa".
logger = logging.getLogger("mfa")
class AdaptingAligner(PretrainedAligner, AdapterMixin):
    """
    Adapt an acoustic model to a new dataset

    Parameters
    ----------
    mapping_tau: int
        Tau to use in mapping stats between new domain data and pretrained model

    See Also
    --------
    :class:`~montreal_forced_aligner.alignment.pretrained.PretrainedAligner`
        For dictionary, corpus, and alignment parameters
    :class:`~montreal_forced_aligner.abc.AdapterMixin`
        For adapting parameters

    Attributes
    ----------
    initialized: bool
        Flag for whether initialization is complete
    adaptation_done: bool
        Flag for whether adaptation is complete
    """

    def __init__(self, mapping_tau: int = 20, **kwargs):
        # The flags are assigned before the parent initializers run, so they
        # already exist while ``super().__init__`` executes.
        self.initialized = False
        self.adaptation_done = False
        super().__init__(**kwargs)
        self.mapping_tau = mapping_tau

    def map_acc_stats_arguments(self, alignment: bool = False) -> List[AccStatsArguments]:
        """
        Generate Job arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsFunction`

        Parameters
        ----------
        alignment: bool
            If True, use :attr:`alignment_model_path` as the model to accumulate
            against, otherwise use :attr:`model_path`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsArguments`]
            Arguments for processing
        """
        model_path = self.alignment_model_path if alignment else self.model_path
        arguments = []
        for j in self.jobs:
            # One feature-processing string per dictionary handled by this job
            feat_strings = {
                d_id: j.construct_feature_proc_string(
                    self.working_directory,
                    d_id,
                    self.feature_options["uses_splices"],
                    self.feature_options["splice_left_context"],
                    self.feature_options["splice_right_context"],
                    self.feature_options["uses_speaker_adaptation"],
                )
                for d_id in j.dictionary_ids
            }
            arguments.append(
                AccStatsArguments(
                    j.id,
                    getattr(self, "db_string", ""),
                    self.working_log_directory.joinpath(f"map_acc_stats.{j.id}.log"),
                    j.dictionary_ids,
                    feat_strings,
                    # Alignments are read from "ali" archives, stats are written to "map" accs
                    j.construct_path_dictionary(self.working_directory, "ali", "ark"),
                    j.construct_path_dictionary(self.working_directory, "map", "acc"),
                    model_path,
                )
            )
        return arguments

    def acc_stats(self, alignment: bool = False) -> None:
        """
        Accumulate stats for the mapped model

        Statistics are accumulated per job, then summed, smoothed towards the
        unadapted model and used to estimate the adapted model
        (``unadapted.mdl`` -> ``final.mdl``, or ``unadapted.alimdl`` ->
        ``final.alimdl`` when ``alignment`` is True).

        Parameters
        ----------
        alignment: bool
            Flag for whether to accumulate stats for the mapped alignment model

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
            If any binary in the estimation pipeline exits with a non-zero status
        """
        arguments = self.map_acc_stats_arguments(alignment)
        extension = "alimdl" if alignment else "mdl"
        initial_mdl_path = self.working_directory.joinpath(f"unadapted.{extension}")
        final_mdl_path = self.working_directory.joinpath(f"final.{extension}")
        logger.info("Accumulating statistics...")
        self._accumulate_map_stats(arguments)
        self._estimate_map_model(arguments, initial_mdl_path, final_mdl_path)

    def _accumulate_map_stats(self, arguments: List[AccStatsArguments]) -> None:
        """Run :class:`AccStatsFunction` for every job, in worker processes when enabled."""
        with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            if not GLOBAL_CONFIG.use_mp:
                for args in arguments:
                    for num_utterances, errors in AccStatsFunction(args).run():
                        pbar.update(num_utterances + errors)
                return

            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                worker = KaldiProcessWorker(i, return_queue, AccStatsFunction(args), stopped)
                procs.append(worker)
                worker.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Nothing queued: stop only once every worker reports it is finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                if isinstance(result, Exception):
                    # Keep draining the queue; the error is re-raised after the join
                    error_dict[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                num_utterances, errors = result
                pbar.update(num_utterances + errors)
            for worker in procs:
                worker.join()
            if error_dict:
                # Surface the first collected worker exception
                raise next(iter(error_dict.values()))

    def _estimate_map_model(
        self,
        arguments: List[AccStatsArguments],
        initial_mdl_path: Path,
        final_mdl_path: Path,
    ) -> None:
        """Pipe ``gmm-sum-accs | gmm-ismooth-stats | gmm-est`` to write the adapted model."""
        log_path = self.working_log_directory.joinpath("map_model_est.log")
        occs_path = self.working_directory.joinpath("final.occs")
        acc_files = [str(path) for args in arguments for path in args.acc_paths.values()]
        with mfa_open(log_path, "w") as log_file:
            sum_proc = subprocess.Popen(
                [thirdparty_binary("gmm-sum-accs"), "-", *acc_files],
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            ismooth_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-ismooth-stats"),
                    "--smooth-from-model",
                    f"--tau={self.mapping_tau}",
                    str(initial_mdl_path),
                    "-",
                    "-",
                ],
                stderr=log_file,
                stdin=sum_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            est_proc = subprocess.Popen(
                [
                    thirdparty_binary("gmm-est"),
                    "--update-flags=m",
                    f"--write-occs={occs_path}",
                    "--remove-low-count-gaussians=false",
                    str(initial_mdl_path),
                    "-",
                    str(final_mdl_path),
                ],
                stdin=ismooth_proc.stdout,
                stderr=log_file,
                env=os.environ,
            )
            # Close the parent's copies of the intermediate pipe ends so an
            # upstream process is not left blocked writing into a pipe that
            # nobody reads if a downstream process exits early.
            sum_proc.stdout.close()
            ismooth_proc.stdout.close()
            est_proc.communicate()
            # Reap every process in the pipeline, not just the last one
            return_codes = (sum_proc.wait(), ismooth_proc.wait(), est_proc.returncode)
        if any(return_codes):
            raise KaldiProcessingError([log_path])

    @property
    def align_directory(self) -> str:
        """Align directory"""
        return os.path.join(self.output_directory, "adapted_align")

    @property
    def working_log_directory(self) -> Path:
        """Current log directory"""
        return self.working_directory.joinpath("log")

    @property
    def model_path(self) -> Path:
        """Current acoustic model path"""
        if self.current_workflow.workflow_type == WorkflowType.acoustic_model_adaptation:
            return self.working_directory.joinpath("unadapted.mdl")
        return self.working_directory.joinpath("final.mdl")

    @property
    def alignment_model_path(self) -> Path:
        """Current alignment model path"""
        if self.current_workflow.workflow_type != WorkflowType.acoustic_model_adaptation:
            return super().alignment_model_path
        path = self.working_directory.joinpath("unadapted.alimdl")
        if os.path.exists(path) and not getattr(self, "uses_speaker_adaptation", False):
            return path
        return self.model_path

    @property
    def next_model_path(self) -> Path:
        """Mapped acoustic model path"""
        return self.working_directory.joinpath("final.mdl")

    def train_map(self) -> None:
        """
        Trains an adapted acoustic model through mapping model states and update those with
        enough data.

        See Also
        --------
        :class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsFunction`
            Multiprocessing helper function for each job
        :meth:`.AdaptingAligner.map_acc_stats_arguments`
            Job method for generating arguments for the helper function
        :kaldi_src:`gmm-sum-accs`
            Relevant Kaldi binary
        :kaldi_src:`gmm-ismooth-stats`
            Relevant Kaldi binary
        :kaldi_src:`gmm-est`
            Relevant Kaldi binary
        :kaldi_steps:`train_map`
            Reference Kaldi script
        """
        begin = time.time()
        os.makedirs(self.working_log_directory, exist_ok=True)
        self.acc_stats(alignment=False)
        if self.uses_speaker_adaptation:
            # The alignment model is optional; only map it when adapt() copied one
            if os.path.exists(self.working_directory.joinpath("unadapted.alimdl")):
                self.acc_stats(alignment=True)
            else:
                logger.debug("No unadapted.alimdl found, skipping alignment model mapping")
        logger.debug(f"Mapping models took {time.time() - begin:.3f} seconds")

    def adapt(self) -> None:
        """
        Run the adaptation

        Aligns the corpus, copies the model files and alignment archives from
        the alignment workflow into a new adaptation workflow, trains the
        mapped model, exports it, and copies the results to
        :attr:`align_directory`. The workflow is flagged ``done`` on success
        and ``dirty`` on failure.

        Raises
        ------
        Exception
            Any exception raised during training or export is re-raised after
            the workflow has been flagged as dirty
        """
        logger.info("Generating initial alignments...")
        self.align()
        alignment_workflow = self.current_workflow
        self.create_new_current_workflow(WorkflowType.acoustic_model_adaptation)
        source_directory = alignment_workflow.working_directory

        # (source name, target name, required) -- the models are renamed to
        # "unadapted.*" so the adapted output can be written as "final.*"
        model_files = (
            ("final.mdl", "unadapted.mdl", True),
            ("final.alimdl", "unadapted.alimdl", False),
            ("tree", "tree", True),
            ("lda.mat", "lda.mat", False),
        )
        for source_name, target_name, required in model_files:
            source_path = os.path.join(source_directory, source_name)
            if not required and not os.path.exists(source_path):
                logger.debug(f"Skipping optional file {source_name}, not found")
                continue
            shutil.copyfile(source_path, self.working_directory.joinpath(target_name))

        # Alignments are required; "trans" archives are copied only when present
        for j in self.jobs:
            for identifier, required in (("ali", True), ("trans", False)):
                old_paths = j.construct_path_dictionary(source_directory, identifier, "ark")
                new_paths = j.construct_path_dictionary(
                    self.working_directory, identifier, "ark"
                )
                for k, v in old_paths.items():
                    if not required and not os.path.exists(v):
                        continue
                    shutil.copyfile(v, new_paths[k])

        os.makedirs(self.align_directory, exist_ok=True)
        try:
            logger.info("Adapting pretrained model...")
            self.train_map()
            self.export_model(self.working_log_directory.joinpath("acoustic_model.zip"))
            for file_name in ("final.mdl", "final.occs", "tree"):
                shutil.copyfile(
                    self.working_directory.joinpath(file_name),
                    os.path.join(self.align_directory, file_name),
                )
            for file_name in ("final.alimdl", "lda.mat"):
                optional_path = self.working_directory.joinpath(file_name)
                if os.path.exists(optional_path):
                    shutil.copyfile(
                        optional_path,
                        os.path.join(self.align_directory, file_name),
                    )
            wf = self.current_workflow
            with self.session() as session:
                session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update(
                    {"done": True}
                )
                session.commit()
        except Exception as e:
            # Flag the workflow as dirty before propagating the error
            wf = self.current_workflow
            with self.session() as session:
                session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update(
                    {"dirty": True}
                )
                session.commit()
            if isinstance(e, KaldiProcessingError):
                log_kaldi_errors(e.error_logs)
                e.update_log_file()
            raise

    @property
    def meta(self) -> MetaDict:
        """Acoustic model metadata"""
        from datetime import datetime

        from ..utils import get_mfa_version

        data = {
            "phones": sorted(self.non_silence_phones),
            "version": get_mfa_version(),
            # The architecture is carried over from the pretrained model
            "architecture": self.acoustic_model.meta["architecture"],
            "train_date": str(datetime.now()),
            "features": self.feature_options,
            "phone_set_type": str(self.phone_set_type),
            "dictionaries": {
                "names": sorted(self.dictionary_base_names.values()),
                "default": self.dictionary_base_names[self._default_dictionary_id],
                "silence_word": self.silence_word,
                "use_g2p": self.use_g2p,
                "oov_word": self.oov_word,
                "bracketed_word": self.bracketed_word,
                "laughter_word": self.laughter_word,
                "clitic_marker": self.clitic_marker,
                "position_dependent_phones": self.position_dependent_phones,
            },
            "oov_phone": self.oov_phone,
            "optional_silence_phone": self.optional_silence_phone,
            "silence_probability": self.silence_probability,
            "initial_silence_probability": self.initial_silence_probability,
            "final_silence_correction": self.final_silence_correction,
            "final_non_silence_correction": self.final_non_silence_correction,
        }
        return data

    def export_model(self, output_model_path: Path) -> None:
        """
        Output an acoustic model to the specified path

        Parameters
        ----------
        output_model_path : :class:`~pathlib.Path`
            Path to save adapted acoustic model; a plain string is also accepted
        """
        output_model_path = Path(output_model_path)
        acoustic_model = AcousticModel.empty(
            output_model_path.stem, root_directory=self.working_log_directory
        )
        acoustic_model.add_meta_file(self)
        acoustic_model.add_model(self.working_directory)
        acoustic_model.add_model(self.phones_dir)
        os.makedirs(output_model_path.parent, exist_ok=True)
        acoustic_model.dump(output_model_path)