"""Class definitions for alignment mixins"""
from __future__ import annotations
import csv
import datetime
import logging
import multiprocessing as mp
import os
import time
from abc import abstractmethod
from pathlib import Path
from queue import Empty
from typing import TYPE_CHECKING, Dict, List
from tqdm.rich import tqdm
from montreal_forced_aligner.alignment.multiprocessing import (
AlignArguments,
AlignFunction,
CompileInformationArguments,
CompileTrainGraphsArguments,
CompileTrainGraphsFunction,
PhoneConfidenceArguments,
PhoneConfidenceFunction,
compile_information_func,
)
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.db import (
CorpusWorkflow,
File,
Job,
PhoneInterval,
Speaker,
Utterance,
bulk_update,
)
from montreal_forced_aligner.dictionary.mixins import DictionaryMixin
from montreal_forced_aligner.exceptions import NoAlignmentsError
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import (
KaldiProcessWorker,
Stopped,
run_kaldi_function,
run_mp,
run_non_mp,
)
if TYPE_CHECKING:
from montreal_forced_aligner.abc import MetaDict
logger = logging.getLogger("mfa")
[docs]
class AlignMixin(DictionaryMixin):
"""
Configuration object for alignment
Parameters
----------
transition_scale : float
Transition scale, defaults to 1.0
acoustic_scale : float
Acoustic scale, defaults to 0.1
self_loop_scale : float
Self-loop scale, defaults to 0.1
boost_silence : float
Factor to boost silence probabilities, 1.0 is no boost or reduction
beam : int
Size of the beam to use in decoding, defaults to 10
retry_beam : int
Size of the beam to use in decoding if it fails with the initial beam width, defaults to 40
See Also
--------
:class:`~montreal_forced_aligner.dictionary.mixins.DictionaryMixin`
For dictionary parsing parameters
Attributes
----------
jobs: list[:class:`~montreal_forced_aligner.corpus.multiprocessing.Job`]
Jobs to process
"""
logger: logging.Logger
jobs: List[Job]
def __init__(
    self,
    transition_scale: float = 1.0,
    acoustic_scale: float = 0.1,
    self_loop_scale: float = 0.1,
    boost_silence: float = 1.0,
    beam: int = 10,
    retry_beam: int = 40,
    fine_tune: bool = False,
    phone_confidence: bool = False,
    use_phone_model: bool = False,
    **kwargs,
):
    """
    Store the alignment parameters on the instance.

    Any additional keyword arguments are forwarded unchanged to the next
    ``__init__`` in the method resolution order.
    """
    super().__init__(**kwargs)

    # Behavior flags
    self.fine_tune = fine_tune
    self.phone_confidence = phone_confidence
    self.use_phone_model = use_phone_model

    # Scaling factors
    self.transition_scale = transition_scale
    self.acoustic_scale = acoustic_scale
    self.self_loop_scale = self_loop_scale
    self.boost_silence = boost_silence

    # Beam widths: a retry beam that is not wider than the initial beam is
    # replaced with four times the initial beam
    self.beam = beam
    self.retry_beam = self.beam * 4 if retry_beam <= self.beam else retry_beam

    # Starts out empty; nothing in this method adds to it
    self.unaligned_files = set()
@property
def tree_path(self) -> Path:
    """Location of the ``tree`` file inside the working directory"""
    return self.working_directory / "tree"
@property
@abstractmethod
def data_directory(self) -> str:
    """
    Corpus data directory, to be supplied by the implementing class

    NOTE(review): annotated as ``str`` while the other directory properties
    of this mixin are annotated as ``Path`` -- confirm with implementers.
    """
[docs]
@abstractmethod
def construct_feature_proc_strings(self) -> List[Dict[str, str]]:
    """
    Generate feature strings, to be supplied by the implementing class

    Returns
    -------
    list[dict[str, str]]
        Feature strings (presumably one mapping per job -- confirm against implementers)
    """
[docs]
def compile_train_graphs_arguments(self) -> List[CompileTrainGraphsArguments]:
    """
    Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsArguments`]
        Arguments for processing, one entry per job
    """
    # Use ``model_path`` when that file is on disk, otherwise the alignment model path
    acoustic_model = self.model_path
    if not os.path.exists(acoustic_model):
        acoustic_model = self.alignment_model_path

    return [
        CompileTrainGraphsArguments(
            job.id,
            getattr(self, "db_string", ""),
            self.working_log_directory.joinpath(f"compile_train_graphs.{job.id}.log"),
            self.working_directory.joinpath("tree"),
            acoustic_model,
            getattr(self, "use_g2p", False),
        )
        for job in self.jobs
    ]
[docs]
def align_arguments(self) -> List[AlignArguments]:
    """
    Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.alignment.multiprocessing.AlignArguments`]
        Arguments for processing, one entry per job
    """
    iteration = getattr(self, "iteration", None)
    arguments = []
    for job in self.jobs:
        # The iteration number is part of the log name only when one is set
        if iteration is None:
            log_name = f"align.{job.id}.log"
        else:
            log_name = f"align.{iteration}.{job.id}.log"
        log_path = self.working_log_directory.joinpath(log_name)
        if getattr(self, "uses_speaker_adaptation", False):
            log_path = log_path.with_suffix(".fmllr.log")

        # Decode options are used only when phone confidence is on, speaker
        # adaptation is in use, and the instance provides ``decode_options``
        use_decode_options = (
            self.phone_confidence
            and getattr(self, "uses_speaker_adaptation", False)
            and hasattr(self, "decode_options")
        )
        if use_decode_options:
            options = self.decode_options
        else:
            options = self.align_options

        arguments.append(
            AlignArguments(
                job.id,
                getattr(self, "db_string", ""),
                log_path,
                self.alignment_model_path,
                options,
                self.feature_options,
                self.phone_confidence,
            )
        )
    return arguments
[docs]
def phone_confidence_arguments(self) -> List[PhoneConfidenceArguments]:
    """
    Generate Job arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneConfidenceFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneConfidenceArguments`]
        Arguments for processing, one entry per job
    """
    # Feature options handed positionally, in this order, to the job's
    # feature string constructor
    option_keys = (
        "uses_splices",
        "splice_left_context",
        "splice_right_context",
        "uses_speaker_adaptation",
    )
    arguments = []
    for job in self.jobs:
        # One feature string per dictionary of the job, keyed by dictionary id
        feature_strings = {
            dictionary.id: job.construct_feature_proc_string(
                self.working_directory,
                dictionary.id,
                *(self.feature_options[key] for key in option_keys),
            )
            for dictionary in job.dictionaries
        }
        arguments.append(
            PhoneConfidenceArguments(
                job.id,
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"phone_confidence.{job.id}.log"),
                self.model_path,
                self.phone_pdf_counts_path,
                feature_strings,
            )
        )
    return arguments
@property
def align_options(self) -> MetaDict:
    """Options for use in aligning, collected from the instance attributes"""
    return dict(
        transition_scale=self.transition_scale,
        acoustic_scale=self.acoustic_scale,
        self_loop_scale=self.self_loop_scale,
        beam=self.beam,
        retry_beam=self.retry_beam,
        boost_silence=self.boost_silence,
        optional_silence_csl=self.optional_silence_csl,
    )
[docs]
def alignment_configuration(self) -> MetaDict:
    """Configuration parameters, keyed by the name of the attribute they are read from"""
    parameter_names = (
        "transition_scale",
        "acoustic_scale",
        "self_loop_scale",
        "boost_silence",
        "beam",
        "retry_beam",
    )
    return {name: getattr(self, name) for name in parameter_names}
@property
def num_current_utterances(self) -> int:
    """Value of ``num_utterances`` when the instance has one, otherwise 0"""
    try:
        return self.num_utterances
    except AttributeError:
        return 0
[docs]
def compile_train_graphs(self) -> None:
    """
    Multiprocessing function that compiles training graphs for utterances.

    With ``GLOBAL_CONFIG.use_mp`` set, one worker process is started per set
    of job arguments and results are read from a shared queue; otherwise the
    jobs run one after another in the current process. Reported errors are
    summed and logged as a warning. If any worker sent back an exception,
    the first one recorded is raised after all workers have been joined.

    See Also
    --------
    :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsFunction`
        Multiprocessing helper function for each job
    :meth:`.AlignMixin.compile_train_graphs_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps:`align_si`
        Reference Kaldi script
    :kaldi_steps:`align_fmllr`
        Reference Kaldi script
    """
    start_time = time.time()
    os.makedirs(self.working_log_directory, exist_ok=True)
    logger.info("Compiling training graphs...")
    failed_count = 0
    job_arguments = self.compile_train_graphs_arguments()
    with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
        if not GLOBAL_CONFIG.use_mp:
            logger.debug("Not using multiprocessing...")
            for job_args in job_arguments:
                for done, errors in CompileTrainGraphsFunction(job_args).run():
                    pbar.update(done + errors)
                    failed_count += errors
        else:
            failures = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            workers = []
            for index, job_args in enumerate(job_arguments):
                worker = KaldiProcessWorker(
                    index, return_queue, CompileTrainGraphsFunction(job_args), stopped
                )
                workers.append(worker)
                worker.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Nothing arrived within the timeout: stop polling only
                    # once every worker's ``finished`` flag is set
                    if all(w.finished.stop_check() for w in workers):
                        break
                    continue
                if isinstance(result, Exception):
                    # Keep the exception and carry on draining the queue
                    failures[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                done, errors = result
                pbar.update(done + errors)
                failed_count += errors
            for worker in workers:
                worker.join()
            if failures:
                raise next(iter(failures.values()))
    if failed_count:
        logger.warning(f"Compilation of training graphs failed for {failed_count} utterances.")
    logger.debug(f"Compiling training graphs took {time.time() - start_time:.3f} seconds")
def get_phone_confidences(self):
    """
    Calculate phone confidences and write them to the phone intervals in the database.

    Returns early with a warning when the file at ``phone_pdf_counts_path``
    does not exist. Otherwise the update mappings produced by each job are
    collected (in worker processes when ``GLOBAL_CONFIG.use_mp`` is set,
    sequentially otherwise) and applied to :class:`PhoneInterval` in one
    bulk update. If any worker sent back an exception, the first one
    recorded is raised after all workers have been joined.

    See Also
    --------
    :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneConfidenceFunction`
        Multiprocessing helper function for each job
    :meth:`.AlignMixin.phone_confidence_arguments`
        Job method for generating arguments for the helper function
    """
    if not os.path.exists(self.phone_pdf_counts_path):
        logger.warning("Cannot calculate phone confidences with the current model.")
        return
    logger.info("Calculating phone confidences...")
    start_time = time.time()
    with self.session() as session, tqdm(
        total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet
    ) as pbar:
        job_arguments = self.phone_confidence_arguments()
        interval_updates = []
        if not GLOBAL_CONFIG.use_mp:
            logger.debug("Not using multiprocessing...")
            for job_args in job_arguments:
                for result in PhoneConfidenceFunction(job_args).run():
                    interval_updates.extend(result)
                    pbar.update(1)
        else:
            failures = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            workers = []
            for index, job_args in enumerate(job_arguments):
                worker = KaldiProcessWorker(
                    index, return_queue, PhoneConfidenceFunction(job_args), stopped
                )
                workers.append(worker)
                worker.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Nothing arrived within the timeout: stop polling only
                    # once every worker's ``finished`` flag is set
                    if all(w.finished.stop_check() for w in workers):
                        break
                    continue
                if isinstance(result, Exception):
                    # Keep the exception and carry on draining the queue
                    failures[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                interval_updates.extend(result)
                pbar.update(1)
            for worker in workers:
                worker.join()
            if failures:
                raise next(iter(failures.values()))
        bulk_update(session, PhoneInterval, interval_updates)
        session.commit()
    logger.debug(f"Calculating phone confidences took {time.time() - start_time:.3f} seconds")
[docs]
def align_utterances(self, training=False) -> None:
    """
    Multiprocessing function that aligns based on the current model.

    Parameters
    ----------
    training: bool
        When False (the default), the stored alignment log-likelihoods are
        cleared first, the values returned by the alignment jobs are saved
        and divided by each utterance's ``num_frames``, and (unless speaker
        adaptation is in use) the current workflow receives a time stamp and
        the mean log-likelihood as its score. When True, the alignment jobs
        run but nothing is written to the database.

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.NoAlignmentsError`
        If ``training`` is False and no utterance was aligned

    See Also
    --------
    :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`
        Multiprocessing helper function for each job
    :meth:`.AlignMixin.align_arguments`
        Job method for generating arguments for the helper function
    :kaldi_steps:`align_si`
        Reference Kaldi script
    :kaldi_steps:`align_fmllr`
        Reference Kaldi script
    """
    begin = time.time()
    logger.info("Generating alignments...")
    with self.session() as session:
        if not training:
            # Clear earlier likelihoods (only the subset's when the instance
            # has a ``subset`` attribute) before new values are written
            utterances = session.query(Utterance)
            if hasattr(self, "subset"):
                utterances = utterances.filter(Utterance.in_subset == True)  # noqa
            utterances.update({"alignment_log_likelihood": None})
            session.commit()
        with tqdm(total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            log_like_sum = 0
            log_like_count = 0
            update_mappings = []
            for utterance, log_likelihood in run_kaldi_function(
                AlignFunction, self.align_arguments(), pbar.update
            ):
                # The generator must be consumed even in training mode, but
                # its results are only recorded outside of training
                if training:
                    continue
                log_like_sum += log_likelihood
                log_like_count += 1
                update_mappings.append(
                    {"id": utterance, "alignment_log_likelihood": log_likelihood}
                )
            if not training:
                if not update_mappings:
                    raise NoAlignmentsError(
                        self.num_current_utterances, self.beam, self.retry_beam
                    )
                bulk_update(session, Utterance, update_mappings)
                # Normalize the stored log-likelihoods by utterance length in frames
                session.query(Utterance).filter(
                    Utterance.alignment_log_likelihood != None  # noqa
                ).update(
                    {
                        Utterance.alignment_log_likelihood: Utterance.alignment_log_likelihood
                        / Utterance.num_frames
                    },
                    synchronize_session="fetch",
                )
        if not training:
            if not getattr(self, "uses_speaker_adaptation", False):
                workflow = (
                    session.query(CorpusWorkflow)
                    .filter(CorpusWorkflow.current == True)  # noqa
                    .first()
                )
                if workflow is None:
                    # ``first()`` yields None when no row matches; failing here
                    # would skip the commit below and lose the updates above
                    logger.warning("No current workflow found, skipping workflow score update.")
                else:
                    workflow.time_stamp = datetime.datetime.now()
                    # log_like_count is non-zero here: an empty result raised above
                    workflow.score = log_like_sum / log_like_count
            session.commit()
        logger.debug(f"Alignment round took {time.time() - begin:.3f} seconds")
@property
@abstractmethod
def working_directory(self) -> Path:
    """Working directory, to be supplied by the implementing class"""
@property
@abstractmethod
def working_log_directory(self) -> Path:
    """Working log directory, to be supplied by the implementing class"""
@property
def model_path(self) -> Path:
    """Path of the ``final.mdl`` acoustic model file inside the working directory"""
    return self.working_directory / "final.mdl"
@property
def phone_pdf_counts_path(self) -> Path:
    """Path of the ``phone_pdf.counts`` file inside the working directory"""
    return self.working_directory / "phone_pdf.counts"
@property
def alignment_model_path(self) -> Path:
    """
    Acoustic model file path for speaker-independent alignment

    ``final.alimdl`` in the working directory is returned when that file
    exists and speaker adaptation is not in use; otherwise ``model_path``.
    """
    candidate = self.working_directory / "final.alimdl"
    if not os.path.exists(candidate) or getattr(self, "uses_speaker_adaptation", False):
        return self.model_path
    return candidate