"""
Alignment multiprocessing functions
-----------------------------------
"""
from __future__ import annotations
import collections
import json
import logging
import math
import multiprocessing as mp
import os
import shutil
import statistics
import sys
import time
import traceback
import typing
from dataclasses import dataclass
from pathlib import Path
from queue import Empty
from typing import TYPE_CHECKING
import librosa.feature
import numpy as np
import sqlalchemy
from _kalpy.gmm import gmm_compute_likes
from _kalpy.hmm import TransitionModel
from _kalpy.util import (
RandomAccessBaseDoubleMatrixReader,
RandomAccessBaseFloatMatrixReader,
RandomAccessInt32VectorVectorReader,
)
from kalpy.aligner import KalpyAligner
from kalpy.decoder.data import FstArchive
from kalpy.decoder.training_graphs import TrainingGraphCompiler
from kalpy.evaluation import align_words, fix_unk_words
from kalpy.fstext.lexicon import LexiconCompiler
from kalpy.gmm.align import GmmAligner
from kalpy.gmm.data import AlignmentArchive, TranscriptionArchive
from kalpy.gmm.train import GmmStatsAccumulator
from kalpy.gmm.utils import read_gmm_model
from kalpy.utils import generate_read_specifier
from sqlalchemy.orm import joinedload, selectinload, subqueryload
from montreal_forced_aligner.abc import KaldiFunction
from montreal_forced_aligner.data import (
WORD_BEGIN_SYMBOL,
WORD_END_SYMBOL,
MfaArguments,
PhoneType,
PronunciationProbabilityCounter,
WordType,
WorkflowType,
)
from montreal_forced_aligner.db import (
CorpusWorkflow,
File,
Job,
Phone,
PhoneInterval,
ReferencePhoneInterval,
SoundFile,
Speaker,
TextFile,
Utterance,
Word,
WordInterval,
)
from montreal_forced_aligner.exceptions import AlignmentCollectionError, AlignmentExportError
from montreal_forced_aligner.helper import mfa_open, split_phone_position
from montreal_forced_aligner.models import AcousticModel
from montreal_forced_aligner.textgrid import construct_textgrid_output
from montreal_forced_aligner.utils import thread_logger
if TYPE_CHECKING:
from montreal_forced_aligner.abc import MetaDict
__all__ = [
"AlignmentExtractionFunction",
"ExportTextGridProcessWorker",
"AlignmentExtractionArguments",
"ExportTextGridArguments",
"AlignFunction",
"AlignArguments",
"AnalyzeAlignmentsFunction",
"AnalyzeAlignmentsArguments",
"AnalyzeTranscriptsFunction",
"AccStatsFunction",
"AccStatsArguments",
"CompileTrainGraphsFunction",
"CompileTrainGraphsArguments",
"GeneratePronunciationsArguments",
"GeneratePronunciationsFunction",
"FineTuneArguments",
"FineTuneFunction",
"PhoneConfidenceArguments",
"PhoneConfidenceFunction",
]
logger = logging.getLogger("mfa")
[docs]
@dataclass
class GeneratePronunciationsArguments(MfaArguments):
"""
Arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.GeneratePronunciationsFunction`
Parameters
----------
job_name: int
Integer ID of the job
session: :class:`sqlalchemy.orm.scoped_session` or str
SqlAlchemy scoped session or string for database connections
log_path: :class:`~pathlib.Path`
Path to save logging information during the run
aligner: :class:`kalpy.gmm.align.GmmAligner`
GmmAligner to use
lexicon_compilers: dict[int, :class:`kalpy.fstext.lexicon.LexiconCompiler`]
Lexicon compilers for each pronunciation dictionary
for_g2p: bool
Flag for training a G2P model with acoustic information
"""
aligner: GmmAligner
lexicon_compilers: typing.Dict[int, LexiconCompiler]
for_g2p: bool
[docs]
@dataclass
class ExportTextGridArguments(MfaArguments):
"""
Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridProcessWorker`
Parameters
----------
job_name: int
Integer ID of the job
session: :class:`sqlalchemy.orm.scoped_session` or str
SqlAlchemy scoped session or string for database connections
log_path: :class:`~pathlib.Path`
Path to save logging information during the run
export_frame_shift: float
Frame shift in seconds
cleanup_textgrids: bool
Flag to cleanup silences and recombine words
clitic_marker: str
Marker indicating clitics
output_directory: :class:`~pathlib.Path`
Directory for exporting
output_format: str
Format to export
include_original_text: bool
Flag for including original unnormalized text as a tier
"""
export_frame_shift: float
cleanup_textgrids: bool
clitic_marker: str
output_directory: Path
output_format: str
include_original_text: bool
[docs]
@dataclass
class CompileTrainGraphsArguments(MfaArguments):
"""
Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsFunction`
Parameters
----------
job_name: int
Integer ID of the job
session: :class:`sqlalchemy.orm.scoped_session` or str
SqlAlchemy scoped session or string for database connections
log_path: :class:`~pathlib.Path`
Path to save logging information during the run
working_directory: :class:`~pathlib.Path`
Working directory
tree_path: :class:`~pathlib.Path`
Path to tree file
model_path: :class:`~pathlib.Path`
Path to model file
use_g2p: bool
Flag for whether acoustic model uses g2p
"""
working_directory: Path
lexicon_compilers: typing.Dict[int, LexiconCompiler]
tree_path: Path
model_path: Path
use_g2p: bool
[docs]
@dataclass
class AlignArguments(MfaArguments):
"""
Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`
Parameters
----------
job_name: int
Integer ID of the job
session: :class:`sqlalchemy.orm.scoped_session` or str
SqlAlchemy scoped session or string for database connections
log_path: :class:`~pathlib.Path`
Path to save logging information during the run
working_directory: :class:`~pathlib.Path`
Working directory
model_path: :class:`~pathlib.Path`
Path to model file
align_options: dict[str, Any]
Alignment options
final: bool
Flag for final alignment pass
"""
working_directory: Path
model_path: Path
align_options: MetaDict
confidence: bool
final: bool
@dataclass
class AnalyzeAlignmentsArguments(MfaArguments):
"""
Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AnalyzeAlignmentsFunction`
Parameters
----------
job_name: int
Integer ID of the job
session: :class:`sqlalchemy.orm.scoped_session` or str
SqlAlchemy scoped session or string for database connections
log_path: :class:`~pathlib.Path`
Path to save logging information during the run
model_path: :class:`~pathlib.Path`
Path to model file
align_options: dict[str, Any]
Alignment options
"""
model_path: Path
align_options: MetaDict
[docs]
@dataclass
class FineTuneArguments(MfaArguments):
"""
Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`
Parameters
----------
job_name: int
Integer ID of the job
session: :class:`sqlalchemy.orm.scoped_session` or str
SqlAlchemy scoped session or string for database connections
log_path: :class:`~pathlib.Path`
Path to save logging information during the run
acoustic_model: :class:`~montreal_forced_aligner.models.AcousticModel`
Acoustic model
lexicon_compilers: dict[int, :class:`~kalpy.lexicon.LexiconCompiler`]
Lexicon compilers
boundary_tolerance: float, optional
Boundary tolerance, defaults to half the feature generation time step
"""
acoustic_model: AcousticModel
lexicon_compilers: typing.Dict[int, LexiconCompiler]
boundary_tolerance: typing.Optional[float]
[docs]
@dataclass
class PhoneConfidenceArguments(MfaArguments):
"""
Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`
Parameters
----------
job_name: int
Integer ID of the job
session: :class:`sqlalchemy.orm.scoped_session` or str
SqlAlchemy scoped session or string for database connections
log_path: :class:`~pathlib.Path`
Path to save logging information during the run
working_directory: :class:`~pathlib.Path`
Working directory
model_path: :class:`~pathlib.Path`
Path to model file
phone_pdf_counts_path: :class:`~pathlib.Path`
Path to output PDF counts
"""
working_directory: Path
model_path: Path
phone_pdf_counts_path: Path
[docs]
@dataclass
class AccStatsArguments(MfaArguments):
"""
Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsFunction`
Parameters
----------
job_name: int
Integer ID of the job
session: :class:`sqlalchemy.orm.scoped_session` or str
SqlAlchemy scoped session or string for database connections
log_path: :class:`~pathlib.Path`
Path to save logging information during the run
working_directory: :class:`~pathlib.Path`
Path to working directory
model_path: :class:`~pathlib.Path`
Path to model file
"""
working_directory: Path
model_path: Path
filter_likely_errors: bool = False
[docs]
class CompileTrainGraphsFunction(KaldiFunction):
"""
Multiprocessing function to compile training graphs
See Also
--------
:meth:`.AlignMixin.compile_train_graphs`
Main function that calls this function in parallel
:meth:`.AlignMixin.compile_train_graphs_arguments`
Job method for generating arguments for this function
:kaldi_src:`compile-train-graphs`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsArguments`
Arguments for the function
"""
def __init__(self, args: CompileTrainGraphsArguments):
super().__init__(args)
self.tree_path = args.tree_path
self.lexicon_compilers = args.lexicon_compilers
self.model_path = args.model_path
self.use_g2p = args.use_g2p
def _run(self):
"""Run the function"""
with self.session() as session, thread_logger(
"kalpy.graphs", self.log_path, job_name=self.job_name
) as graph_logger:
graph_logger.debug(f"Tree path: {self.tree_path}")
graph_logger.debug(f"Model path: {self.model_path}")
job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
.filter(Job.id == self.job_name)
.first()
)
workflow: CorpusWorkflow = (
session.query(CorpusWorkflow)
.filter(CorpusWorkflow.current == True) # noqa
.first()
)
interjection_costs = {}
if workflow.workflow_type is WorkflowType.transcript_verification:
interjection_words = (
session.query(Word).filter(Word.word_type == WordType.interjection).all()
)
if interjection_words:
max_count = max(math.log(x.count) for x in interjection_words)
for w in interjection_words:
count = math.log(w.count)
if count == 0:
count = 0.01
cost = max_count / count
interjection_costs[w.word] = cost
if self.use_g2p:
text_column = Utterance.normalized_character_text
else:
text_column = Utterance.normalized_text
for d in job.training_dictionaries:
begin = time.time()
if self.lexicon_compilers and d.id in self.lexicon_compilers:
lexicon = self.lexicon_compilers[d.id]
else:
lexicon = d.lexicon_compiler
if workflow.workflow_type is WorkflowType.transcript_verification:
if interjection_words and d.oov_word not in interjection_costs:
interjection_costs[d.oov_word] = min(interjection_costs.values())
# interjection_costs[d.cutoff_word] = min(interjection_costs.values())
compiler = TrainingGraphCompiler(
self.model_path,
self.tree_path,
lexicon,
use_g2p=self.use_g2p,
batch_size=500
if workflow.workflow_type is not WorkflowType.transcript_verification
else 250,
)
graph_logger.debug(f"Set up took {time.time() - begin} seconds")
query = (
session.query(Utterance.kaldi_id, text_column)
.join(Utterance.speaker)
.filter(Utterance.job_id == self.job_name, Speaker.dictionary_id == d.id)
.filter(Utterance.ignored == False) # noqa
.order_by(Utterance.kaldi_id)
)
if job.corpus.current_subset > 0:
query = query.filter(Utterance.in_subset == True) # noqa
graph_logger.info(f"Compiling graphs for {d.name}")
fst_ark_path = job.construct_path(
workflow.working_directory, "fsts", "ark", d.name
)
compiler.export_graphs(
fst_ark_path,
query,
# callback=self.callback,
interjection_words=interjection_costs,
# cutoff_pattern = d.cutoff_word
)
graph_logger.debug(f"Total compilation time: {time.time() - begin} seconds")
del compiler
del lexicon
[docs]
class AccStatsFunction(KaldiFunction):
"""
Multiprocessing function for accumulating stats in GMM training.
See Also
--------
:meth:`.AcousticModelTrainingMixin.acc_stats`
Main function that calls this function in parallel
:meth:`.AcousticModelTrainingMixin.acc_stats_arguments`
Job method for generating arguments for this function
:kaldi_src:`gmm-acc-stats-ali`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsArguments`
Arguments for the function
"""
def __init__(self, args: AccStatsArguments):
super().__init__(args)
self.working_directory = args.working_directory
self.model_path = args.model_path
self.filter_likely_errors = args.filter_likely_errors
def _run(self) -> None:
"""Run the function"""
with self.session() as session, thread_logger(
"kalpy.train", self.log_path, job_name=self.job_name
) as train_logger:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
.filter(Job.id == self.job_name)
.first()
)
for d in job.training_dictionaries:
train_logger.debug(f"Accumulating stats for dictionary {d.name} ({d.id})")
train_logger.debug(f"Accumulating stats for model: {self.model_path}")
accumulator = GmmStatsAccumulator(self.model_path)
ignored_keys = None
if self.filter_likely_errors:
ignored_keys = set(
x[0]
for x in session.query(Utterance.kaldi_id)
.filter(Utterance.job_id == self.job_name) # noqa
.filter(
sqlalchemy.or_(
sqlalchemy.and_(
Utterance.manual_alignments == True, # noqa
Utterance.alignment_log_likelihood < -1000,
),
sqlalchemy.and_(
Utterance.manual_alignments == False, # noqa
Utterance.duration_deviation > 10,
),
sqlalchemy.and_(
Utterance.manual_alignments == False, # noqa
Utterance.snr <= 1.0,
),
)
)
.all()
)
train_logger.debug(
f"Ignoring {len(ignored_keys)} utterances due to likely alignment errors"
)
feature_archive = job.construct_feature_archive(self.working_directory, d.name)
ali_path = job.construct_path(self.working_directory, "ali", "ark", d.name)
if not ali_path.exists():
continue
alignment_archive = AlignmentArchive(ali_path)
train_logger.debug("Feature Archive information:")
train_logger.debug(f"CMVN: {feature_archive.cmvn_read_specifier}")
train_logger.debug(f"Deltas: {feature_archive.use_deltas}")
train_logger.debug(f"Splices: {feature_archive.use_splices}")
train_logger.debug(f"LDA: {feature_archive.lda_mat_file_name}")
train_logger.debug(f"fMLLR: {feature_archive.transform_read_specifier}")
train_logger.debug(f"Alignment path: {ali_path}")
accumulator.accumulate_stats(
feature_archive,
alignment_archive,
callback=self.callback,
ignored_keys=ignored_keys,
)
self.callback((accumulator.transition_accs, accumulator.gmm_accs))
[docs]
class AlignFunction(KaldiFunction):
"""
Multiprocessing function for alignment.
See Also
--------
:meth:`.AlignMixin.align_utterances`
Main function that calls this function in parallel
:meth:`.AlignMixin.align_arguments`
Job method for generating arguments for this function
:kaldi_src:`align-gmm-compiled`
Relevant Kaldi binary
:kaldi_src:`gmm-boost-silence`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignArguments`
Arguments for the function
"""
def __init__(self, args: AlignArguments):
super().__init__(args)
self.working_directory = args.working_directory
self.model_path = args.model_path
self.align_options = args.align_options
self.confidence = args.confidence
self.final = args.final
def _run(self) -> None:
"""Run the function"""
with self.session() as session, thread_logger(
"kalpy.align", self.log_path, job_name=self.job_name
) as align_logger:
align_logger.debug(f"Align options: {self.align_options}")
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
.filter(Job.id == self.job_name)
.first()
)
align_options = self.align_options
boost_silence = align_options.pop("boost_silence", 1.0)
silence_phones = [
x
for x, in session.query(Phone.mapping_id).filter(
Phone.phone_type == PhoneType.silence, Phone.phone != "<eps>"
)
]
aligner = GmmAligner(
self.model_path,
**align_options,
)
aligner.boost_silence(boost_silence, silence_phones)
for d in job.training_dictionaries:
align_logger.debug(f"Aligning for dictionary {d.name} ({d.id})")
align_logger.debug(f"Aligning with model: {aligner.acoustic_model_path}")
fst_path = job.construct_path(self.working_directory, "fsts", "ark", d.name)
align_logger.debug(f"Training graph archive: {fst_path}")
feature_archive = job.construct_feature_archive(self.working_directory, d.name)
align_logger.debug("Feature Archive information:")
align_logger.debug(f"Archive: {feature_archive.file_name}")
align_logger.debug(f"CMVN: {feature_archive.cmvn_read_specifier}")
align_logger.debug(f"Deltas: {feature_archive.use_deltas}")
align_logger.debug(f"Splices: {feature_archive.use_splices}")
align_logger.debug(f"LDA: {feature_archive.lda_mat_file_name}")
align_logger.debug(f"fMLLR: {feature_archive.transform_read_specifier}")
training_graph_archive = FstArchive(fst_path)
ali_path = job.construct_path(self.working_directory, "ali", "ark", d.name)
words_path = job.construct_path(self.working_directory, "words", "ark", d.name)
likes_path = job.construct_path(
self.working_directory, "likelihoods", "ark", d.name
)
ali_path.unlink(missing_ok=True)
words_path.unlink(missing_ok=True)
likes_path.unlink(missing_ok=True)
reference_phones_path = job.construct_path(
job.corpus.current_subset_directory, "ref_phones", "ark", d.name
)
if reference_phones_path.exists():
align_logger.debug(f"Reference phones: {reference_phones_path}")
reference_phone_archive = RandomAccessInt32VectorVectorReader(
generate_read_specifier(reference_phones_path)
)
else:
reference_phone_archive = None
if aligner.acoustic_model_path.endswith(".alimdl"):
ali_path = job.construct_path(
self.working_directory, "ali_first_pass", "ark", d.name
)
words_path = job.construct_path(
self.working_directory, "words_first_pass", "ark", d.name
)
likes_path = job.construct_path(
self.working_directory, "likelihoods_first_pass", "ark", d.name
)
try:
aligner.export_alignments(
ali_path,
training_graph_archive,
feature_archive,
reference_phone_archive=reference_phone_archive,
word_file_name=words_path,
likelihood_file_name=likes_path,
callback=self.callback,
)
except Exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
error_text = "\n".join(
traceback.format_exception(exc_type, exc_value, exc_traceback)
)
align_logger.debug(error_text)
raise
finally:
if reference_phone_archive is not None:
reference_phone_archive.Close()
if aligner.acoustic_model_path.endswith(".alimdl"):
try:
job.construct_path(
self.working_directory, "ali", "ark", d.name
).symlink_to(ali_path)
job.construct_path(
self.working_directory, "words", "ark", d.name
).symlink_to(words_path)
job.construct_path(
self.working_directory, "likelihoods", "ark", d.name
).symlink_to(likes_path)
except OSError as e:
logger.debug(str(e))
shutil.copyfile(
ali_path,
job.construct_path(self.working_directory, "ali", "ark", d.name),
)
shutil.copyfile(
words_path,
job.construct_path(self.working_directory, "words", "ark", d.name),
)
shutil.copyfile(
likes_path,
job.construct_path(
self.working_directory, "likelihoods", "ark", d.name
),
)
class AnalyzeAlignmentsFunction(KaldiFunction):
"""
Multiprocessing function for analyzing alignments.
See Also
--------
:meth:`.CorpusAligner.analyze_alignments`
Main function that calls this function in parallel
:meth:`.CorpusAligner.calculate_speech_post_arguments`
Job method for generating arguments for this function
:kaldi_src:`lattice-to-post`
Relevant Kaldi binary
:kaldi_src:`weight-silence-post`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.alignment.multiprocessing.CalculateSpeechPostArguments`
Arguments for the function
"""
def __init__(self, args: AnalyzeAlignmentsArguments):
super().__init__(args)
self.model_path = args.model_path
self.align_options = args.align_options
def _run(self):
"""Run the function"""
with self.session() as session, thread_logger(
"kalpy.align", self.log_path, job_name=self.job_name
) as extraction_logger:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
.filter(Job.id == self.job_name)
.first()
)
phones = {
k: (m, sd)
for k, m, sd in session.query(
Phone.id, Phone.mean_duration, Phone.sd_duration
).filter(
Phone.phone_type == PhoneType.non_silence,
Phone.sd_duration != None, # noqa
Phone.sd_duration != 0,
)
}
for d in job.training_dictionaries:
query = (
session.query(Utterance)
.join(Utterance.speaker)
.filter(
Utterance.job_id == job.id,
Speaker.dictionary_id == d.id,
Utterance.alignment_log_likelihood != None, # noqa
)
.options(
joinedload(Utterance.speaker, innerjoin=True),
joinedload(Utterance.file, innerjoin=True).joinedload(
File.sound_file, innerjoin=True
),
)
)
silence_phone_id = (
session.query(Phone.id)
.filter(Phone.phone == d.optional_silence_phone)
.first()[0]
)
for utterance in query:
phone_intervals = (
session.query(PhoneInterval)
.join(PhoneInterval.phone)
.filter(
PhoneInterval.utterance_id == utterance.id,
)
.all()
)
if not phone_intervals:
continue
audio = utterance.segment.load_audio()
sr = 16_000
rms = librosa.feature.rms(
y=audio, frame_length=int(0.025 * sr), hop_length=int(0.01 * sr)
)[0, ...]
interval_count = 0
log_like_sum = 0
duration_zscore_max = 0
silence_energy_sum = 0
silence_frame_count = 0
nonsilence_energy_sum = 0
nonsilence_frame_count = 0
for pi in phone_intervals:
begin_index = int((pi.begin - utterance.begin) / 0.01)
end_index = int((pi.end - utterance.begin) / 0.01)
if pi.phone_id == silence_phone_id:
silence_frame_count += end_index - begin_index
silence_energy_sum += rms[begin_index:end_index].sum()
elif pi.phone_id in phones:
nonsilence_frame_count += end_index - begin_index
nonsilence_energy_sum += rms[begin_index:end_index].sum()
if pi.phone_id not in phones:
continue
interval_count += 1
log_like_sum += pi.phone_goodness
m, sd = phones[pi.phone_id]
duration_zscore = abs((pi.duration - m) / sd)
if duration_zscore > duration_zscore_max:
duration_zscore_max = duration_zscore
try:
utterance_speech_log_likelihood = log_like_sum / interval_count
utterance_duration_deviation = duration_zscore_max
except ZeroDivisionError:
utterance_speech_log_likelihood = None
utterance_duration_deviation = None
try:
silence_energy = silence_energy_sum / silence_frame_count
except ZeroDivisionError:
silence_energy = rms.min()
try:
nonsilence_energy = nonsilence_energy_sum / nonsilence_frame_count
except ZeroDivisionError:
nonsilence_energy = rms.max()
nonsilence_energy = 10 * math.log10(
max(nonsilence_energy, sys.float_info.epsilon)
)
silence_energy = 10 * math.log10(max(silence_energy, sys.float_info.epsilon))
snr = nonsilence_energy - silence_energy
extraction_logger.debug(
f"{utterance.id}: Energy shape: {rms.shape}, Nonsilence energy: {nonsilence_energy} ({nonsilence_frame_count}), Silence energy: {silence_energy} ({silence_frame_count}), SNR: {snr}"
)
self.callback(
(
utterance.id,
utterance_speech_log_likelihood,
utterance_duration_deviation,
snr,
)
)
class AnalyzeTranscriptsFunction(KaldiFunction):
"""
Multiprocessing function for analyzing alignments.
See Also
--------
:meth:`.CorpusAligner.analyze_alignments`
Main function that calls this function in parallel
:meth:`.CorpusAligner.calculate_speech_post_arguments`
Job method for generating arguments for this function
:kaldi_src:`lattice-to-post`
Relevant Kaldi binary
:kaldi_src:`weight-silence-post`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.alignment.multiprocessing.CalculateSpeechPostArguments`
Arguments for the function
"""
def __init__(self, args: AnalyzeAlignmentsArguments):
super().__init__(args)
self.model_path = args.model_path
self.align_options = args.align_options
def _run(self):
"""Run the function"""
with self.session() as session:
job: Job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
.filter(Job.id == self.job_name)
.first()
)
query = session.query(Utterance).filter(
Utterance.job_id == job.id, Utterance.alignment_log_likelihood != None # noqa
)
for utterance in query:
word_intervals = [
x.as_ctm()
for x in (
session.query(WordInterval)
.join(WordInterval.word)
.filter(
WordInterval.utterance_id == utterance.id,
Word.word_type != WordType.silence,
WordInterval.end - WordInterval.begin > 0.03,
)
.options(
joinedload(WordInterval.word, innerjoin=True),
)
.order_by(WordInterval.begin)
)
]
if not word_intervals:
continue
extra_duration, wer, aligned_duration = align_words(
utterance.normalized_text.split(), word_intervals, "<eps>", debug=True
)
transcript = " ".join(x.label for x in word_intervals)
self.callback((utterance.id, wer, extra_duration, transcript))
[docs]
class FineTuneFunction(KaldiFunction):
"""
Multiprocessing function for fine-tuning alignment.
Parameters
----------
args: :class:`~montreal_forced_aligner.alignment.multiprocessing.FineTuneArguments`
Arguments for the function
"""
def __init__(self, args: FineTuneArguments):
super().__init__(args)
self.acoustic_model = args.acoustic_model
self.lexicon_compilers = args.lexicon_compilers
self.boundary_tolerance = args.boundary_tolerance
def _run(self):
"""Run the function"""
with self.session() as session, thread_logger(
"kalpy.align", self.log_path, job_name=self.job_name
):
job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
.filter(Job.id == self.job_name)
.first()
)
workflow: CorpusWorkflow = (
session.query(CorpusWorkflow)
.filter(CorpusWorkflow.current == True) # noqa
.first()
)
cmvn_paths = job.per_dictionary_cmvn_scp_paths
trans_paths = job.per_dictionary_trans_scp_paths
for d in job.dictionaries:
aligner = KalpyAligner(
acoustic_model=self.acoustic_model,
lexicon_compiler=self.lexicon_compilers[d.id],
)
ali_path = job.construct_path(workflow.working_directory, "ali", ".ark", d.name)
alignment_archive = AlignmentArchive(ali_path)
utterance_query = (
session.query(Utterance, SoundFile.sound_file_path)
.join(Utterance.file)
.join(Utterance.speaker)
.join(File.sound_file)
.filter(Utterance.job_id == self.job_name, Speaker.dictionary_id == d.id)
.order_by(Utterance.kaldi_id)
)
cmvn_reader = None
cmvn_path = cmvn_paths[d.id]
if cmvn_path.exists():
cmvn_read_specifier = generate_read_specifier(cmvn_path)
cmvn_reader = RandomAccessBaseDoubleMatrixReader(cmvn_read_specifier)
fmllr_path = trans_paths[d.id]
transform_reader = None
if fmllr_path.exists():
transform_read_specifier = generate_read_specifier(fmllr_path)
transform_reader = RandomAccessBaseFloatMatrixReader(transform_read_specifier)
current_speaker = None
current_transform = None
current_cmvn = None
for utterance, sf_path in utterance_query:
try:
alignment = alignment_archive[utterance.id]
except KeyError:
continue
interval_query = (
session.query(PhoneInterval)
.filter(
PhoneInterval.utterance_id == utterance.id,
)
.order_by(PhoneInterval.begin)
)
if utterance.speaker_id != current_speaker:
current_speaker = utterance.speaker_id
if cmvn_reader is not None and cmvn_reader.HasKey(str(current_speaker)):
current_cmvn = cmvn_reader.Value(str(current_speaker))
if transform_reader is not None and transform_reader.HasKey(
str(current_speaker)
):
current_transform = transform_reader.Value(str(current_speaker))
ctm = aligner.fine_tune_alignments(
utterance.to_kalpy(),
alignment,
boundary_tolerance=self.boundary_tolerance,
cmvn=current_cmvn,
fmllr_trans=current_transform,
)
new_boundaries = ctm.phone_boundaries
interval_mapping = []
for i, interval in enumerate(interval_query):
begin = interval.begin
if i != 0:
begin = new_boundaries[i - 1]
if i < len(new_boundaries) - 1:
interval_mapping.append(
{"id": interval.id, "begin": begin, "end": new_boundaries[i]}
)
self.callback(interval_mapping)
[docs]
class PhoneConfidenceFunction(KaldiFunction):
"""
Multiprocessing function to calculate phone confidence metrics
See Also
--------
:kaldi_src:`gmm-compute-likes`
Relevant Kaldi binary
:kaldi_src:`transform-feats`
Relevant Kaldi binary
Parameters
----------
args: :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneConfidenceArguments`
Arguments for the function
"""
def __init__(self, args: PhoneConfidenceArguments):
super().__init__(args)
self.working_directory = args.working_directory
self.model_path = args.model_path
self.phone_pdf_counts_path = args.phone_pdf_counts_path
def _run(self):
"""Run the function"""
with mfa_open(self.phone_pdf_counts_path, "r") as f:
data = json.load(f)
phone_pdf_mapping = collections.defaultdict(collections.Counter)
for phone, pdf_counts in data.items():
phone = split_phone_position(phone)[0]
for pdf, count in pdf_counts.items():
phone_pdf_mapping[phone][int(pdf)] += count
phones = {p: i for i, p in enumerate(sorted(phone_pdf_mapping.keys()))}
reversed_phones = {k: v for v, k in phones.items()}
for phone, pdf_counts in phone_pdf_mapping.items():
phone_total = sum(pdf_counts.values())
for pdf, count in pdf_counts.items():
phone_pdf_mapping[phone][int(pdf)] = count / phone_total
_, acoustic_model = read_gmm_model(self.model_path)
with self.session() as session:
job: typing.Optional[Job] = session.get(
Job, self.job_name, options=[joinedload(Job.dictionaries), joinedload(Job.corpus)]
)
utterances = (
session.query(Utterance)
.filter(Utterance.job_id == self.job_name)
.options(
selectinload(Utterance.phone_intervals).joinedload(
PhoneInterval.phone, innerjoin=True
),
selectinload(Utterance.reference_phone_intervals).options(
joinedload(ReferencePhoneInterval.phone, innerjoin=True),
),
)
)
utterances = {u.id: (u.begin, u.phone_intervals) for u in utterances}
for d in job.dictionaries:
feature_archive = job.construct_feature_archive(self.working_directory, d.name)
interval_mappings = []
for utterance_id, feats in feature_archive:
utterance_id = int(utterance_id.split("-")[-1])
likelihoods = gmm_compute_likes(acoustic_model, feats).numpy()
phone_likes = np.zeros((likelihoods.shape[0], len(phones)))
for i, p in reversed_phones.items():
like = likelihoods[:, [x for x in phone_pdf_mapping[p].keys()]]
weight = np.array([x for x in phone_pdf_mapping[p].values()])
phone_likes[:, i] = np.dot(like, weight)
top_phone_inds = np.argmax(phone_likes, axis=1)
utt_begin, intervals = utterances[utterance_id]
for pi in intervals:
if pi.phone.phone == "sil":
continue
frame_begin = int(((pi.begin - utt_begin) * 1000) / 10)
frame_end = int(((pi.end - utt_begin) * 1000) / 10)
if frame_begin == frame_end:
frame_end += 1
frame_end = min(frame_end, top_phone_inds.shape[0])
alternate_labels = collections.Counter()
scores = []
for i in range(frame_begin, frame_end):
top_phone_ind = top_phone_inds[i]
alternate_label = reversed_phones[top_phone_ind]
alternate_label = split_phone_position(alternate_label)[0]
alternate_labels[alternate_label] += 1
if alternate_label == pi.phone.phone:
scores.append(0)
else:
actual_score = phone_likes[i, phones[pi.phone.phone]]
scores.append(phone_likes[i, top_phone_ind] - actual_score)
average_score = statistics.mean(scores)
interval_mappings.append(
{"id": pi.id, "phone_goodness": float(average_score)}
)
self.callback(interval_mappings)
interval_mappings = []
[docs]
class GeneratePronunciationsFunction(KaldiFunction):
"""
Multiprocessing function for generating pronunciations
See Also
--------
:meth:`.DictionaryTrainer.export_lexicons`
Main function that calls this function in parallel
:meth:`.CorpusAligner.generate_pronunciations_arguments`
Job method for generating arguments for this function
:kaldi_src:`linear-to-nbest`
Kaldi binary this uses
Parameters
----------
args: :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignArguments`
Arguments for the function
"""
def __init__(self, args: GeneratePronunciationsArguments):
super().__init__(args)
self.aligner = args.aligner
self.lexicon_compilers = args.lexicon_compilers
self.for_g2p = args.for_g2p
self.silence_words = set()
def _process_pronunciations(
self, word_pronunciations: typing.List[typing.Tuple[str, str]]
) -> PronunciationProbabilityCounter:
"""
Process an utterance's pronunciations and extract relevant count information
Parameters
----------
word_pronunciations: list[tuple[str, tuple[str, ...]]]
List of tuples containing the word integer ID and a list of the integer IDs of the phones
"""
counter = PronunciationProbabilityCounter()
word_pronunciations = [("<s>", "")] + word_pronunciations + [("</s>", "")]
for i, w_p in enumerate(word_pronunciations):
if i != 0:
word = word_pronunciations[i - 1][0]
if word in self.silence_words:
counter.silence_before_counts[w_p] += 1
else:
counter.non_silence_before_counts[w_p] += 1
silence_check = w_p[0] in self.silence_words
if not silence_check:
counter.word_pronunciation_counts[w_p[0]][w_p[1]] += 1
if i != len(word_pronunciations) - 1:
word = word_pronunciations[i + 1][0]
if word in self.silence_words:
counter.silence_following_counts[w_p] += 1
if i != len(word_pronunciations) - 2:
next_w_p = word_pronunciations[i + 2]
counter.ngram_counts[w_p, next_w_p]["silence"] += 1
else:
next_w_p = word_pronunciations[i + 1]
counter.non_silence_following_counts[w_p] += 1
counter.ngram_counts[w_p, next_w_p]["non_silence"] += 1
return counter
def _run(self) -> None:
"""Run the function"""
with self.session() as session:
job = (
session.query(Job)
.options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
.filter(Job.id == self.job_name)
.first()
)
workflow: CorpusWorkflow = (
session.query(CorpusWorkflow)
.filter(CorpusWorkflow.current == True) # noqa
.first()
)
silence_words = session.query(Word.word).filter(Word.word_type == WordType.silence)
self.silence_words.update(x for x, in silence_words)
for d in job.training_dictionaries:
ali_path = job.construct_path(workflow.working_directory, "ali", "ark", d.name)
if not os.path.exists(ali_path):
continue
if self.lexicon_compilers and d.id in self.lexicon_compilers:
lexicon_compiler = self.lexicon_compilers[d.id]
else:
lexicon_compiler = d.lexicon_compiler
words_path = job.construct_path(workflow.working_directory, "words", "ark", d.name)
alignment_archive = AlignmentArchive(ali_path, words_file_name=words_path)
for alignment in alignment_archive:
intervals = alignment.generate_ctm(
self.aligner.transition_model, lexicon_compiler.phone_table
)
utterance = int(alignment.utterance_id.split("-")[-1])
ctm = lexicon_compiler.phones_to_pronunciations(alignment.words, intervals)
word_pronunciations = []
for wi in ctm.word_intervals:
label = wi.label
pronunciation = wi.pronunciation
if label.startswith(d.cutoff_word[:-1]):
label = d.cutoff_word
if pronunciation != d.oov_phone:
pronunciation = "cutoff_model"
word_pronunciations.append((label, pronunciation))
if self.for_g2p:
phones = []
for i, x in enumerate(word_pronunciations):
if i > 0 and (
x[0].startswith(d.clitic_marker)
or word_pronunciations[i - 1][0].endswith(d.clitic_marker)
):
phones.pop(-1)
else:
phones.append(WORD_BEGIN_SYMBOL)
phones.extend(x[1].split())
phones.append(WORD_END_SYMBOL)
self.callback((d.id, utterance, " ".join(phones)))
else:
self.callback((d.id, self._process_pronunciations(word_pronunciations)))
[docs]
class ExportTextGridProcessWorker(mp.Process):
"""
Multiprocessing worker for exporting TextGrids
See Also
--------
:meth:`.CorpusAligner.collect_alignments`
Main function that runs this worker in parallel
Parameters
----------
for_write_queue: :class:`~multiprocessing.Queue`
Input queue of files to export
stopped: :class:`~multiprocessing.Event`
Stop check for processing
finished_adding: :class:`~multiprocessing.Event`
Input signal that all jobs have been added and no more new ones will come in
textgrid_errors: dict[str, str]
Dictionary for storing errors encountered
arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridArguments`
Arguments to pass to the TextGrid export function
exported_file_count: :class:`~montreal_forced_aligner.utils.Counter`
Counter for exported files
"""
def __init__(
self,
db_string: str,
for_write_queue: mp.Queue,
return_queue: mp.Queue,
stopped: mp.Event,
finished_adding: mp.Event,
export_frame_shift: float,
cleanup_textgrids: bool,
clitic_marker: str,
output_directory: Path,
output_format: str,
include_original_text: bool,
):
super().__init__()
self.db_string = db_string
self.for_write_queue = for_write_queue
self.return_queue = return_queue
self.stopped = stopped
self.finished_adding = finished_adding
self.finished_processing = mp.Event()
self.output_directory = output_directory
self.output_format = output_format
self.export_frame_shift = export_frame_shift
self.include_original_text = include_original_text
self.cleanup_textgrids = cleanup_textgrids
self.clitic_marker = clitic_marker
[docs]
def run(self) -> None:
"""Run the exporter function"""
db_engine = sqlalchemy.create_engine(self.db_string)
with sqlalchemy.orm.Session(db_engine) as session:
while True:
try:
(file_batch) = self.for_write_queue.get(timeout=1)
except Empty:
if self.finished_adding.is_set():
self.finished_processing.set()
break
continue
if self.stopped.is_set():
continue
try:
for output_path in construct_textgrid_output(
session,
file_batch,
self.cleanup_textgrids,
self.clitic_marker,
self.output_directory,
self.export_frame_shift,
self.output_format,
self.include_original_text,
):
self.return_queue.put(1)
except Exception:
exc_type, exc_value, exc_traceback = sys.exc_info()
self.return_queue.put(
AlignmentExportError(
output_path,
traceback.format_exception(exc_type, exc_value, exc_traceback),
)
)
self.stopped.set()