Source code for montreal_forced_aligner.alignment.multiprocessing

"""
Alignment multiprocessing functions
-----------------------------------

"""
from __future__ import annotations

import collections
import json
import logging
import multiprocessing as mp
import os
import re
import statistics
import subprocess
import sys
import traceback
import typing
from pathlib import Path
from queue import Empty
from typing import TYPE_CHECKING, Dict, List, Union

import numpy as np
import pynini
import pywrapfst
import sqlalchemy
from pynini.lib import rewrite
from sqlalchemy.orm import Session, joinedload, selectinload, subqueryload

from montreal_forced_aligner.corpus.features import (
    compute_mfcc_process,
    compute_pitch_process,
    compute_transform_process,
)
from montreal_forced_aligner.data import (
    CtmInterval,
    MfaArguments,
    PhoneType,
    PronunciationProbabilityCounter,
    WordCtmInterval,
    WordType,
)
from montreal_forced_aligner.db import (
    CorpusWorkflow,
    DictBundle,
    File,
    Job,
    Phone,
    PhoneInterval,
    Pronunciation,
    SoundFile,
    Speaker,
    Utterance,
    Word,
)
from montreal_forced_aligner.exceptions import AlignmentExportError, FeatureGenerationError
from montreal_forced_aligner.helper import align_pronunciations, mfa_open, split_phone_position
from montreal_forced_aligner.textgrid import (
    construct_output_path,
    construct_output_tiers,
    export_textgrid,
)
from montreal_forced_aligner.utils import (
    Counter,
    KaldiFunction,
    Stopped,
    parse_ctm_output,
    read_feats,
    thirdparty_binary,
)

if TYPE_CHECKING:
    from dataclasses import dataclass

    from montreal_forced_aligner.abc import MetaDict
else:
    from dataclassy import dataclass


# Names exported by ``from montreal_forced_aligner.alignment.multiprocessing import *``.
# NOTE(review): several entries (e.g. "AlignmentExtractionFunction",
# "ExportTextGridProcessWorker", "compile_information_func",
# "GeneratePronunciationsFunction") are not defined in the portion of the module
# reviewed here — presumably they are defined elsewhere in this module; confirm.
__all__ = [
    "AlignmentExtractionFunction",
    "ExportTextGridProcessWorker",
    "AlignmentExtractionArguments",
    "ExportTextGridArguments",
    "AlignFunction",
    "AnalyzeAlignmentsFunction",
    "AlignArguments",
    "AnalyzeAlignmentsArguments",
    "AccStatsFunction",
    "AccStatsArguments",
    "compile_information_func",
    "CompileInformationArguments",
    "CompileTrainGraphsFunction",
    "CompileTrainGraphsArguments",
    "GeneratePronunciationsArguments",
    "GeneratePronunciationsFunction",
]

# Module-level logger; every message from this module is emitted under the "mfa" logger name.
logger = logging.getLogger("mfa")


def phones_to_prons(
    text: str,
    intervals: List[CtmInterval],
    align_lexicon_fst: pynini.Fst,
    word_pronunciations: typing.Dict[str, typing.Set[str]],
    word_symbol_table: pywrapfst.SymbolTableView,
    phone_symbol_table: pywrapfst.SymbolTableView,
    optional_silence_phone: str,
    transcription: bool = False,
    oov_phone: str = None,
    oov_word: str = None,
    use_g2p: bool = False,
    silence_words: typing.Set[str] = None,
    position_dependent_phones=False,
):
    """
    Group an aligned phone sequence into per-word pronunciations

    A linear FST over the interval labels is composed with the alignment lexicon
    restricted to the transcript; the shortest path is read back as a phone string,
    cut at the ``#1``/``#2`` word boundary symbols and handed to
    :func:`~montreal_forced_aligner.helper.align_pronunciations`.

    Parameters
    ----------
    text: str
        Transcript of the utterance. Split on whitespace, or, when ``use_g2p`` is set,
        split on ``<space>`` with the spaces inside each chunk removed
    intervals: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
        Aligned phone intervals; only their ``label`` is read
    align_lexicon_fst: :class:`pynini.Fst`
        Lexicon FST that is composed with the transcript acceptor
    word_pronunciations: dict[str, set[str]]
        Forwarded unchanged to ``align_pronunciations``
    word_symbol_table: :class:`pywrapfst.SymbolTableView`
        Word symbols, used for OOV detection and as token type of the transcript acceptor
    phone_symbol_table: :class:`pywrapfst.SymbolTableView`
        Phone symbols; must contain ``#1`` and ``#2``
    optional_silence_phone: str
        Label of the optional silence phone
    transcription: bool
        When set, extra self-loops are added so that phones outside the reference
        path can be consumed (``intervals`` must then be non-empty)
    oov_phone: str
        Forwarded unchanged to ``align_pronunciations``
    oov_word: str
        Replacement for words missing from ``word_symbol_table`` (non-G2P mode only)
    use_g2p: bool
        Treat ``text`` as a character sequence delimited by ``<space>``
    silence_words: set[str]
        Silence words; the first one in sorted order is forwarded to
        ``align_pronunciations``, so the set must be non-empty
    position_dependent_phones: bool
        Strip the position part of every phone via ``split_phone_position`` before
        aligning pronunciations

    Returns
    -------
    The return value of ``align_pronunciations`` for the recovered word splits
    """
    begin_marker = "#1"
    end_marker = "#2"
    begin_label = phone_symbol_table.find(begin_marker)
    end_label = phone_symbol_table.find(end_marker)

    if use_g2p:
        words = [chunk.replace(" ", "") for chunk in text.split("<space>")]
        kaldi_text = text
    else:
        words = text.split()
        # Words unknown to the symbol table are mapped onto the OOV word
        kaldi_text = " ".join(
            [word if word_symbol_table.member(word) else oov_word for word in words]
        )

    phone_to_word = pynini.compose(
        align_lexicon_fst, pynini.accep(kaldi_text, token_type=word_symbol_table)
    )

    phone_fst = pynini.Fst()
    weight_type = phone_fst.weight_type()

    def unit_arc(ilabel, olabel, destination):
        """Build an arc carrying the semiring's unit weight."""
        return pywrapfst.Arc(ilabel, olabel, pywrapfst.Weight.one(weight_type), destination)

    # One state per phone boundary, chained in interval order
    last_state = phone_fst.add_state()
    phone_fst.set_start(last_state)
    for interval in intervals:
        following_state = phone_fst.add_state()
        label = phone_symbol_table.find(interval.label)
        phone_fst.add_arc(last_state, unit_arc(label, label, following_state))
        last_state = following_state

    if transcription:
        # The loops sit one state earlier when the utterance ends in optional silence
        ends_in_silence = intervals[-1].label == optional_silence_phone
        loop_state = last_state - 1 if ends_in_silence else last_state
        lexicon_state = phone_to_word.num_states() - 1
        epsilon = phone_symbol_table.find("<eps>")
        for symbol_id in range(phone_symbol_table.num_symbols()):
            symbol_text = phone_symbol_table.find(symbol_id)
            # Epsilon and "#"-prefixed symbols never receive extra loops
            if symbol_text == "<eps>" or symbol_text.startswith("#"):
                continue
            phone_fst.add_arc(loop_state, unit_arc(epsilon, symbol_id, loop_state))
            phone_to_word.add_arc(lexicon_state, unit_arc(symbol_id, epsilon, lexicon_state))

    # Word boundary markers may be consumed at any position in the phone chain
    for state in range(last_state + 1):
        phone_fst.add_arc(state, unit_arc(end_label, end_label, state))
        phone_fst.add_arc(state, unit_arc(begin_label, begin_label, state))

    phone_fst.set_final(last_state, pywrapfst.Weight.one(weight_type))
    phone_fst.arcsort("olabel")

    lattice = pynini.compose(phone_fst, phone_to_word)
    try:
        path_string = pynini.shortestpath(lattice).project("input").string(phone_symbol_table)
    except Exception:
        # Dump both operands at debug level before propagating the failure
        for message in (
            "For the text and intervals:",
            text,
            kaldi_text,
            [x.label for x in intervals],
            "There was an issue composing word and phone FSTs",
            "PHONE FST:",
        ):
            logger.debug(message)
        phone_fst.set_input_symbols(phone_symbol_table)
        phone_fst.set_output_symbols(phone_symbol_table)
        logger.debug(phone_fst)
        logger.debug("PHONE_TO_WORD FST:")
        phone_to_word.set_input_symbols(phone_symbol_table)
        phone_to_word.set_output_symbols(word_symbol_table)
        logger.debug(phone_to_word)
        raise

    # Collapse every boundary into a single begin marker, then cut on it
    path_string = re.sub(f" {end_marker}$", "", path_string)
    path_string = path_string.replace(f"{end_marker} {begin_marker}", begin_marker).replace(
        end_marker, begin_marker
    )
    path_string = re.sub(f"^{begin_marker} ", "", path_string)
    word_splits = list(filter(None, re.split(rf" ?{begin_marker} ?", path_string)))

    if position_dependent_phones:
        position_free = []
        for pronunciation in word_splits:
            bare_phones = [split_phone_position(phone)[0] for phone in pronunciation.split()]
            position_free.append(" ".join(bare_phones))
        word_splits = position_free

    return align_pronunciations(
        words,
        word_splits,
        oov_phone,
        optional_silence_phone,
        sorted(silence_words)[0],
        word_pronunciations,
    )


@dataclass
class GeneratePronunciationsArguments(MfaArguments):
    """
    Arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.GeneratePronunciationsFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    model_path: :class:`~pathlib.Path`
        Acoustic model path
    for_g2p: bool
        Flag for training a G2P model with acoustic information
    """

    # NOTE(review): the previous docstring also listed ``text_int_paths`` and
    # ``ali_paths``; neither is declared on this class — confirm whether
    # MfaArguments supplies them before documenting them again.
    model_path: Path
    for_g2p: bool
@dataclass
class AlignmentExtractionArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignmentExtractionFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    model_path: :class:`~pathlib.Path`
        Acoustic model path
    frame_shift: float
        Frame shift in seconds
    phone_symbol_path: :class:`~pathlib.Path`
        Path to phone symbols table
    score_options: dict[str, Any]
        Options for Kaldi functions
    confidence: bool
        Confidence flag; its exact effect is defined by the consuming function
    transcription: bool
        Transcription flag; its exact effect is defined by the consuming function
    """

    # NOTE(review): the previous docstring also listed ``ali_paths`` and
    # ``text_int_paths``; neither is declared on this class — confirm whether
    # MfaArguments supplies them before documenting them again.
    model_path: Path
    frame_shift: float
    phone_symbol_path: Path
    score_options: MetaDict
    confidence: bool
    transcription: bool
@dataclass
class ExportTextGridArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridProcessWorker`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    export_frame_shift: float
        Frame shift in seconds
    cleanup_textgrids: bool
        Whether silences are cleaned up and words recombined on export
    clitic_marker: str
        Marker that indicates clitics
    output_directory: :class:`~pathlib.Path`
        Directory that exported files are written to
    output_format: str
        Export format
    include_original_text: bool
        Whether the original, unnormalized text is included as a tier
    """

    export_frame_shift: float
    cleanup_textgrids: bool
    clitic_marker: str
    output_directory: Path
    output_format: str
    include_original_text: bool
@dataclass
class CompileInformationArguments(MfaArguments):
    """
    Arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.compile_information_func`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    align_log_path: :class:`~pathlib.Path`
        Alignment log file that gets parsed
    """

    align_log_path: Path
@dataclass
class CompileTrainGraphsArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    tree_path: :class:`~pathlib.Path`
        Location of the tree file
    model_path: :class:`~pathlib.Path`
        Location of the model file
    use_g2p: bool
        Whether the acoustic model uses G2P
    """

    tree_path: Path
    model_path: Path
    use_g2p: bool
@dataclass
class AlignArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    model_path: :class:`~pathlib.Path`
        Location of the model file
    align_options: dict[str, Any]
        Options controlling alignment
    feature_options: dict[str, Any]
        Options controlling feature generation
    confidence: bool
        Whether confidence output is requested
    """

    model_path: Path
    align_options: MetaDict
    feature_options: MetaDict
    confidence: bool
@dataclass
class AnalyzeAlignmentsArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AnalyzeAlignmentsFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    model_path: :class:`~pathlib.Path`
        Location of the model file
    align_options: dict[str, Any]
        Options controlling alignment
    """

    model_path: Path
    align_options: MetaDict
@dataclass
class FineTuneArguments(MfaArguments):
    """
    Arguments for fine tuning alignments

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    phone_symbol_table_path: :class:`~pathlib.Path`
        Path to the phone symbol table
    disambiguation_symbols_int_path: :class:`~pathlib.Path`
        Path to the integer disambiguation symbols file
    tree_path: :class:`~pathlib.Path`
        Path to tree file
    model_path: :class:`~pathlib.Path`
        Path to model file
    frame_shift: int
        Frame shift in ms
    mfcc_options: dict[str, Any]
        MFCC computation options
    pitch_options: dict[str, Any]
        Pitch computation options
    lda_options: dict[str, Any]
        LDA options
    align_options: dict[str, Any]
        Alignment options
    position_dependent_phones: bool
        Flag for whether to use position dependent phones
    grouped_phones: dict[str, list[str]]
        Grouped lists of phones
    """

    # NOTE(review): the previous docstring named AlignFunction as the consumer of
    # these arguments, which looks copied from AlignArguments — confirm which
    # function actually receives this class.
    phone_symbol_table_path: Path
    disambiguation_symbols_int_path: Path
    tree_path: Path
    model_path: Path
    frame_shift: int
    mfcc_options: MetaDict
    pitch_options: MetaDict
    lda_options: MetaDict
    align_options: MetaDict
    position_dependent_phones: bool
    grouped_phones: Dict[str, List[str]]
@dataclass
class PhoneConfidenceArguments(MfaArguments):
    """
    Arguments for computing phone confidence

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    model_path: :class:`~pathlib.Path`
        Location of the model file
    phone_pdf_counts_path: :class:`~pathlib.Path`
        Location the PDF counts are written to
    feature_strings: dict[int, str]
        Feature generation string for each dictionary id
    """

    # NOTE(review): the previous docstring named AlignFunction as the consumer of
    # these arguments, which looks copied from AlignArguments — confirm which
    # function actually receives this class.
    model_path: Path
    phone_pdf_counts_path: Path
    feature_strings: Dict[int, str]
@dataclass
class AccStatsArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    dictionaries: list[int]
        Ids of the dictionaries to process
    feature_strings: dict[int, str]
        Feature generation string for each dictionary id
    ali_paths: dict[int, Path]
        Alignment path for each dictionary id
    acc_paths: dict[int, Path]
        Accumulated stats path for each dictionary id
    model_path: :class:`~pathlib.Path`
        Location of the model file
    """

    dictionaries: List[int]
    feature_strings: Dict[int, str]
    ali_paths: Dict[int, Path]
    acc_paths: Dict[int, Path]
    model_path: Path
class CompileTrainGraphsFunction(KaldiFunction):
    """
    Multiprocessing function to compile training graphs

    See Also
    --------
    :meth:`.AlignMixin.compile_train_graphs`
        Main function that calls this function in parallel
    :meth:`.AlignMixin.compile_train_graphs_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`compile-train-graphs`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.alignment.multiprocessing.CompileTrainGraphsArguments`
        Arguments for the function
    """

    progress_pattern = re.compile(
        r"^LOG.*succeeded for (?P<succeeded>\d+) graphs, failed for (?P<failed>\d+)"
    )

    def __init__(self, args: CompileTrainGraphsArguments):
        super().__init__(args)
        self.tree_path = args.tree_path
        self.model_path = args.model_path
        self.use_g2p = args.use_g2p

    def _run(self) -> typing.Generator[typing.Tuple[int, int], None, None]:
        """
        Run the function

        Yields
        ------
        tuple[int, int]
            Number of graphs that succeeded and number that failed since the last update
        """
        with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session:
            job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            workflow: CorpusWorkflow = (
                session.query(CorpusWorkflow)
                .filter(CorpusWorkflow.current == True)  # noqa
                .first()
            )
            # tree-info is queried in both modes, although only the G2P path uses the result
            context_width, central_pos = self._read_tree_context()
            if self.use_g2p:
                yield from self._compile_g2p_graphs(
                    session, job, workflow, log_file, context_width, central_pos
                )
            else:
                yield from self._compile_lexicon_graphs(job, workflow, log_file)

    def _read_tree_context(self) -> typing.Tuple[int, int]:
        """Return ``(context_width, central_position)`` as reported by ``tree-info``."""
        tree_proc = subprocess.Popen(
            [thirdparty_binary("tree-info"), self.tree_path],
            encoding="utf8",
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, _ = tree_proc.communicate()
        # Defaults apply when tree-info does not report a value
        context_width = 1
        central_pos = 0
        for line in stdout.split("\n"):
            text = line.strip().split(" ")
            if text[0] == "context-width":
                context_width = int(text[1])
            elif text[0] == "central-position":
                central_pos = int(text[1])
        return context_width, central_pos

    def _compile_g2p_graphs(
        self, session, job, workflow, log_file, context_width: int, central_pos: int
    ) -> typing.Generator[typing.Tuple[int, int], None, None]:
        """Compile one graph per utterance from its character text, yielding ``(1, 0)`` per graph."""
        # Imported here, as in the original code, rather than at module level
        from montreal_forced_aligner.g2p.generator import threshold_lattice_to_dfa

        out_disambig = os.path.join(workflow.working_directory, f"{self.job_name}.disambig")
        ilabels_temp = os.path.join(workflow.working_directory, f"{self.job_name}.ilabels")
        clg_path = os.path.join(workflow.working_directory, f"{self.job_name}.clg.temp")
        ha_out_disambig = os.path.join(
            workflow.working_directory, f"{self.job_name}.ha_out_disambig.temp"
        )
        for d in job.dictionaries:
            log_file.write(f"Compiling graphs for {d.name} ({d.id})...\n")
            fst = pynini.Fst.read(d.lexicon_fst_path)
            token_type = pywrapfst.SymbolTable.read_text(d.grapheme_symbol_table_path)
            text_column = Utterance.normalized_character_text
            fst.invert()
            utterances = (
                session.query(Utterance.kaldi_id, text_column)
                .join(Utterance.speaker)
                .filter(Utterance.ignored == False)  # noqa
                .filter(text_column != "")
                .filter(Utterance.job_id == self.job_name)
                .filter(Speaker.dictionary_id == d.id)
                .order_by(Utterance.kaldi_id)
            )
            fst_ark_path = job.construct_path(workflow.working_directory, "fsts", "ark", d.id)
            with mfa_open(fst_ark_path, "wb") as fst_output_file:
                for utt_id, full_text in utterances:
                    try:
                        lattice = rewrite.rewrite_lattice(full_text, fst, token_type)
                        lattice = threshold_lattice_to_dfa(lattice, 2.0)
                        lattice.invert()
                        serialized_lattice = lattice.write_to_string()
                    except pynini.lib.rewrite.Error:
                        # Utterances that cannot be rewritten are logged and skipped
                        log_file.write(f'Error composing "{full_text}"\n')
                        log_file.flush()
                        continue
                    # Stage 1: compose with context and sort arcs; the result lands in clg_path
                    clg_compose_proc = subprocess.Popen(
                        [
                            thirdparty_binary("fstcomposecontext"),
                            f"--context-size={context_width}",
                            f"--central-position={central_pos}",
                            f"--read-disambig-syms={d.disambiguation_symbols_int_path}",
                            f"--write-disambig-syms={out_disambig}",
                            ilabels_temp,
                            "-",
                            "-",
                        ],
                        stdin=subprocess.PIPE,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        env=os.environ,
                    )
                    clg_sort_proc = subprocess.Popen(
                        [
                            thirdparty_binary("fstarcsort"),
                            "--sort_type=ilabel",
                            "-",
                            clg_path,
                        ],
                        stdin=clg_compose_proc.stdout,
                        stderr=log_file,
                        env=os.environ,
                    )
                    clg_compose_proc.stdin.write(serialized_lattice)
                    clg_compose_proc.stdin.flush()
                    clg_compose_proc.stdin.close()
                    clg_sort_proc.communicate()
                    # Stage 2: build H, compose with the sorted CLG and post-process
                    make_h_proc = subprocess.Popen(
                        [
                            thirdparty_binary("make-h-transducer"),
                            f"--disambig-syms-out={ha_out_disambig}",
                            ilabels_temp,
                            self.tree_path,
                            self.model_path,
                        ],
                        stderr=log_file,
                        stdout=subprocess.PIPE,
                        env=os.environ,
                    )
                    hclg_compose_proc = subprocess.Popen(
                        [thirdparty_binary("fsttablecompose"), "-", clg_path, "-"],
                        stderr=log_file,
                        stdin=make_h_proc.stdout,
                        stdout=subprocess.PIPE,
                        env=os.environ,
                    )
                    hclg_determinize_proc = subprocess.Popen(
                        [thirdparty_binary("fstdeterminizestar"), "--use-log=true"],
                        stdin=hclg_compose_proc.stdout,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        env=os.environ,
                    )
                    hclg_rmsymbols_proc = subprocess.Popen(
                        [thirdparty_binary("fstrmsymbols"), ha_out_disambig],
                        stdin=hclg_determinize_proc.stdout,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        env=os.environ,
                    )
                    hclg_rmeps_proc = subprocess.Popen(
                        [thirdparty_binary("fstrmepslocal")],
                        stdin=hclg_rmsymbols_proc.stdout,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        env=os.environ,
                    )
                    hclg_minimize_proc = subprocess.Popen(
                        [thirdparty_binary("fstminimizeencoded")],
                        stdin=hclg_rmeps_proc.stdout,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        env=os.environ,
                    )
                    hclg_self_loop_proc = subprocess.Popen(
                        [
                            thirdparty_binary("add-self-loops"),
                            "--self-loop-scale=0.1",
                            "--reorder=true",
                            self.model_path,
                            "-",
                            "-",
                        ],
                        stdin=hclg_minimize_proc.stdout,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        env=os.environ,
                    )
                    stdout, _ = hclg_self_loop_proc.communicate()
                    # NOTE(review): the exit status checked here is that of the minimize
                    # stage, not of add-self-loops, whose output is written below —
                    # confirm this is intended.
                    self.check_call(hclg_minimize_proc)
                    fst_output_file.write(utt_id.encode("utf8") + b" ")
                    fst_output_file.write(stdout)
                    yield 1, 0

    def _compile_lexicon_graphs(
        self, job, workflow, log_file
    ) -> typing.Generator[typing.Tuple[int, int], None, None]:
        """Run ``compile-train-graphs`` once per dictionary, yielding its progress counts."""
        text_int_paths = job.per_dictionary_text_int_scp_paths
        batch_size = 1000
        for d in job.dictionaries:
            # Trailing newline keeps the first Kaldi log line from being glued onto this one
            log_file.write(f"Compiling graphs for {d}\n")
            fst_ark_path = job.construct_path(workflow.working_directory, "fsts", "ark", d.id)
            text_path = text_int_paths[d.id]
            proc = subprocess.Popen(
                [
                    thirdparty_binary("compile-train-graphs"),
                    f"--read-disambig-syms={d.disambiguation_symbols_int_path}",
                    f"--batch-size={batch_size}",
                    self.tree_path,
                    self.model_path,
                    d.lexicon_fst_path,
                    f"ark,s,cs:{text_path}",
                    f"ark:{fst_ark_path}",
                ],
                stderr=subprocess.PIPE,
                encoding="utf8",
                env=os.environ,
            )
            for line in proc.stderr:
                log_file.write(line)
                log_file.flush()
                m = self.progress_pattern.match(line.strip())
                if m:
                    yield int(m.group("succeeded")), int(m.group("failed"))
            self.check_call(proc)
class AccStatsFunction(KaldiFunction):
    """
    Multiprocessing function for accumulating stats in GMM training.

    See Also
    --------
    :meth:`.AcousticModelTrainingMixin.acc_stats`
        Main function that calls this function in parallel
    :meth:`.AcousticModelTrainingMixin.acc_stats_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`gmm-acc-stats-ali`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.alignment.multiprocessing.AccStatsArguments`
        Arguments for the function
    """

    progress_pattern = re.compile(
        r"^LOG \(gmm-acc-stats-ali.* Processed (?P<utterances>\d+) utterances;.*"
    )

    done_pattern = re.compile(
        r"^LOG \(gmm-acc-stats-ali.*Done (?P<utterances>\d+) files, (?P<errors>\d+) with errors.$"
    )

    def __init__(self, args: AccStatsArguments):
        super().__init__(args)
        self.dictionaries = args.dictionaries
        self.feature_strings = args.feature_strings
        self.model_path = args.model_path
        self.ali_paths = args.ali_paths
        self.acc_paths = args.acc_paths

    def _run(self) -> typing.Generator[typing.Tuple[int, int]]:
        """
        Run the function

        Yields the number of utterances processed since the previous update together
        with an error count, which is non-zero only for the final "Done" line.
        """
        with mfa_open(self.log_path, "w") as log_file:
            for dict_id in self.dictionaries:
                command = [
                    thirdparty_binary("gmm-acc-stats-ali"),
                    self.model_path,
                    self.feature_strings[dict_id],
                    f"ark,s,cs:{self.ali_paths[dict_id]}",
                    self.acc_paths[dict_id],
                ]
                # Kaldi reports cumulative totals; remember the last one to emit deltas
                reported_so_far = 0
                acc_proc = subprocess.Popen(
                    command,
                    stderr=subprocess.PIPE,
                    encoding="utf8",
                    env=os.environ,
                )
                for raw_line in acc_proc.stderr:
                    log_file.write(raw_line)
                    stripped = raw_line.strip()
                    progress = self.progress_pattern.match(stripped)
                    if progress:
                        total = int(progress.group("utterances"))
                        delta = total - reported_so_far
                        reported_so_far = total
                        yield delta, 0
                        continue
                    finished = self.done_pattern.match(stripped)
                    if finished:
                        remaining = int(finished.group("utterances")) - reported_so_far
                        yield remaining, int(finished.group("errors"))
                self.check_call(acc_proc)
class AlignFunction(KaldiFunction):
    """
    Multiprocessing function for alignment.

    See Also
    --------
    :meth:`.AlignMixin.align_utterances`
        Main function that calls this function in parallel
    :meth:`.AlignMixin.align_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`align-gmm-compiled`
        Relevant Kaldi binary
    :kaldi_src:`gmm-boost-silence`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignArguments`
        Arguments for the function
    """

    progress_pattern = re.compile(
        r"^LOG.*Log-like per frame for utterance (?P<utterance>.*) is (?P<loglike>[-\d.]+) over (?P<num_frames>\d+) frames."
    )

    def __init__(self, args: AlignArguments):
        super().__init__(args)
        self.model_path = args.model_path
        self.align_options = args.align_options
        self.feature_options = args.feature_options
        self.confidence = args.confidence

    @staticmethod
    def _utterance_id(kaldi_id: str) -> int:
        """Return the integer after the last ``-`` of a Kaldi utterance id."""
        return int(kaldi_id.split("-")[-1])

    def _run(self) -> typing.Generator[typing.Tuple[int, float], None, None]:
        """
        Run the function

        Yields
        ------
        tuple[int, float]
            Utterance id and the log-likelihood reported for it

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.FeatureGenerationError`
            If the aligner output mentions utterances without features
        """
        with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session:
            job: Job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            workflow: CorpusWorkflow = (
                session.query(CorpusWorkflow)
                .filter(CorpusWorkflow.current == True)  # noqa
                .first()
            )
            for d in job.dictionaries:
                dict_id = d.id
                word_symbols_path = d.words_symbol_path
                feature_string = job.construct_feature_proc_string(
                    workflow.working_directory,
                    dict_id,
                    self.feature_options["uses_splices"],
                    self.feature_options["splice_left_context"],
                    self.feature_options["splice_right_context"],
                    self.feature_options["uses_speaker_adaptation"],
                )
                fst_path = job.construct_path(workflow.working_directory, "fsts", "ark", dict_id)
                fmllr_path = job.construct_path(
                    workflow.working_directory, "trans", "ark", dict_id
                )
                ali_path = job.construct_path(workflow.working_directory, "ali", "ark", dict_id)
                like_path = job.construct_path(workflow.working_directory, "like", "ark", dict_id)
                # Decided once per dictionary so that output parsing below always
                # matches the command that was actually launched.
                lattice_mode = bool(
                    self.confidence
                    and self.feature_options["uses_speaker_adaptation"]
                    and os.path.exists(fmllr_path)
                )
                if lattice_mode:
                    ali_path = job.construct_path(
                        workflow.working_directory, "lat", "ark", dict_id
                    )
                    com = [
                        thirdparty_binary("gmm-latgen-faster"),
                        f"--acoustic-scale={self.align_options['acoustic_scale']}",
                        f"--beam={self.align_options['beam']}",
                        f"--max-active={self.align_options['max_active']}",
                        f"--lattice-beam={self.align_options['lattice_beam']}",
                        f"--word-symbol-table={word_symbols_path}",
                        "--allow-partial=true",
                        self.model_path,
                        f"ark,s,cs:{fst_path}",
                        feature_string,
                        f"ark:{ali_path}",
                    ]
                    align_proc = subprocess.Popen(
                        com, stderr=subprocess.PIPE, env=os.environ, encoding="utf8"
                    )
                    # Progress is parsed from the log lines on stderr in this mode
                    process_stream = align_proc.stderr
                else:
                    com = [
                        thirdparty_binary("gmm-align-compiled"),
                        f"--transition-scale={self.align_options['transition_scale']}",
                        f"--acoustic-scale={self.align_options['acoustic_scale']}",
                        f"--self-loop-scale={self.align_options['self_loop_scale']}",
                        f"--beam={self.align_options['beam']}",
                        f"--retry-beam={self.align_options['retry_beam']}",
                        "--careful=false",
                        f"--write-per-frame-acoustic-loglikes=ark:{like_path}",
                        "-",
                        f"ark,s,cs:{fst_path}",
                        feature_string,
                        f"ark:{ali_path}",
                        "ark,t:-",
                    ]
                    # The silence-boosted model is piped into the aligner's stdin
                    boost_proc = subprocess.Popen(
                        [
                            thirdparty_binary("gmm-boost-silence"),
                            f"--boost={self.align_options['boost_silence']}",
                            self.align_options["optional_silence_csl"],
                            self.model_path,
                            "-",
                        ],
                        stderr=log_file,
                        stdout=subprocess.PIPE,
                        env=os.environ,
                    )
                    align_proc = subprocess.Popen(
                        com,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        encoding="utf8",
                        stdin=boost_proc.stdout,
                        env=os.environ,
                    )
                    # Each stdout line is "<utterance> <log-likelihood>" in this mode
                    process_stream = align_proc.stdout
                no_feature_count = 0
                for line in process_stream:
                    if "No features for utterance" in line:
                        no_feature_count += 1
                    line = line.strip()
                    if lattice_mode:
                        log_file.write(line + "\n")
                        m = self.progress_pattern.match(line)
                        if m:
                            yield self._utterance_id(m.group("utterance")), float(
                                m.group("loglike")
                            )
                    else:
                        utterance, log_likelihood = line.split()
                        yield self._utterance_id(utterance), float(log_likelihood)
                if no_feature_count:
                    align_proc.wait()
                    raise FeatureGenerationError(
                        f"There was an issue in feature generation for {no_feature_count} utterances. "
                        f"This can be caused by version incompatibilities between MFA and the model, "
                        f"in which case you should re-download or re-train your model, "
                        f"or downgrade MFA to the version that the model was trained on."
                    )
                self.check_call(align_proc)
class AnalyzeAlignmentsFunction(KaldiFunction):
    """
    Multiprocessing function for analyzing alignments.

    For each utterance of the job that has an alignment log-likelihood, the phone
    intervals of the current workflow that belong to non-silence/OOV phones with a
    usable duration standard deviation are summarised into two per-utterance
    averages: the mean ``phone_goodness`` and the mean absolute duration z-score.
    The work is done entirely through database queries; no Kaldi binary is run here.

    See Also
    --------
    :meth:`.CorpusAligner.analyze_alignments`
        Main function that calls this function in parallel
    :meth:`.CorpusAligner.calculate_speech_post_arguments`
        Job method for generating arguments for this function

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.alignment.multiprocessing.AnalyzeAlignmentsArguments`
        Arguments for the function
    """

    # Kept as a class attribute for backwards compatibility; ``_run`` does not use it.
    progress_pattern = re.compile(
        r"^LOG.*Log-like per frame for utterance (?P<utterance>.*) is (?P<loglike>[-\d.]+) over (?P<num_frames>\d+) frames."
    )

    def __init__(self, args: AnalyzeAlignmentsArguments):
        super().__init__(args)
        self.model_path = args.model_path
        self.align_options = args.align_options

    def _run(self) -> typing.Generator[typing.Tuple[int, float, float], None, None]:
        """
        Run the function

        Yields
        ------
        tuple[int, float, float]
            Utterance ID, mean ``phone_goodness`` over the utterance's scored phone
            intervals, and mean absolute duration z-score over the same intervals
        """
        with Session(self.db_engine()) as session:
            job: Job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            workflow = (
                session.query(CorpusWorkflow)
                .filter(CorpusWorkflow.current == True)  # noqa
                .first()
            )
            # phone id -> (mean duration, duration standard deviation); phones with a
            # missing or zero standard deviation are excluded so the z-score below
            # never divides by zero
            phones = {
                k: (m, sd)
                for k, m, sd in session.query(
                    Phone.id, Phone.mean_duration, Phone.sd_duration
                ).filter(
                    Phone.phone_type.in_([PhoneType.non_silence, PhoneType.oov]),
                    Phone.sd_duration != None,  # noqa
                    Phone.sd_duration != 0,
                )
            }
            # Loop-invariant: build the ID list once instead of once per utterance
            phone_ids = list(phones)
            query = session.query(Utterance).filter(
                Utterance.job_id == job.id, Utterance.alignment_log_likelihood != None  # noqa
            )
            for utterance in query:
                phone_intervals = (
                    session.query(PhoneInterval)
                    .join(PhoneInterval.phone)
                    .filter(
                        PhoneInterval.utterance_id == utterance.id,
                        PhoneInterval.workflow_id == workflow.id,
                        Phone.id.in_(phone_ids),
                    )
                    .all()
                )
                if not phone_intervals:
                    # Nothing to average for this utterance
                    continue
                interval_count = len(phone_intervals)
                log_like_sum = 0
                duration_zscore_sum = 0
                for pi in phone_intervals:
                    log_like_sum += pi.phone_goodness
                    m, sd = phones[pi.phone_id]
                    duration_zscore_sum += abs((pi.duration - m) / sd)
                utterance_speech_log_likelihood = log_like_sum / interval_count
                utterance_duration_deviation = duration_zscore_sum / interval_count
                yield utterance.id, utterance_speech_log_likelihood, utterance_duration_deviation
class FineTuneFunction(KaldiFunction):
    """
    Multiprocessing function for fine tuning alignment.

    Each phone boundary (every phone interval except the first one of an utterance)
    is re-aligned inside a small window around its original begin time, using
    features computed with a frame shift ``scaling_factor`` times smaller than the
    original one.

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.alignment.multiprocessing.FineTuneArguments`
        Arguments for the function
    """

    def __init__(self, args: FineTuneArguments):
        super().__init__(args)
        self.frame_shift = args.frame_shift
        # Factor by which the frame shift is reduced for the fine-tuning pass
        self.scaling_factor = 10
        self.frame_shift_seconds = round(self.frame_shift / 1000, 3)
        self.new_frame_shift = int(self.frame_shift / self.scaling_factor)
        self.new_frame_shift_seconds = round(self.new_frame_shift / 1000, 4)
        # Feature windows are this many times wider than the alignment window
        self.feature_padding_factor = 4
        # Half-width of the alignment window around a boundary: one original frame
        self.padding = round(self.frame_shift_seconds, 3)
        self.tree_path = args.tree_path
        self.model_path = args.model_path
        # NOTE: the option dictionaries from ``args`` are modified in place
        self.mfcc_options = args.mfcc_options
        self.mfcc_options["frame-shift"] = self.new_frame_shift
        self.mfcc_options["snip-edges"] = False
        self.pitch_options = args.pitch_options
        self.pitch_options["frame-shift"] = self.new_frame_shift
        self.pitch_options["snip-edges"] = False
        self.lda_options = args.lda_options
        self.align_options = args.align_options
        self.grouped_phones = args.grouped_phones
        self.position_dependent_phones = args.position_dependent_phones
        self.disambiguation_symbols_int_path = args.disambiguation_symbols_int_path
        # Per phone-interval bookkeeping filled in by ``setup_files``
        self.segment_begins = {}
        self.segment_ends = {}
        self.original_intervals = {}
        # utterance id -> first phone interval of the utterance (it has no boundary
        # to fine tune, but its end must follow the next interval's new begin)
        self.utterance_initial_intervals = {}

    def setup_files(
        self, session: Session, job: Job, workflow: CorpusWorkflow, dictionary_id: int
    ):
        """
        Write the Kaldi input files describing every boundary to fine tune

        Writes the ``fine_tune_wav``, ``fine_tune_segments``,
        ``fine_tune_feature_segments``, ``fine_tune_utt2spk`` and ``fine_tune_text``
        files for the job, and fills ``utterance_initial_intervals``,
        ``original_intervals``, ``segment_begins`` and ``segment_ends``.

        ``self.phone_to_group_mapping`` must already be set (``_run`` builds it
        before calling this method).

        Parameters
        ----------
        session: :class:`sqlalchemy.orm.Session`
            Database session
        job: :class:`~montreal_forced_aligner.db.Job`
            Job being processed
        workflow: :class:`~montreal_forced_aligner.db.CorpusWorkflow`
            Current workflow, used for the working directory and interval filter
        dictionary_id: int
            Dictionary ID used to construct the output paths
        """
        wav_path = job.construct_path(
            workflow.working_directory, "fine_tune_wav", "scp", dictionary_id
        )
        segment_path = job.construct_path(
            workflow.working_directory, "fine_tune_segments", "scp", dictionary_id
        )
        feature_segment_path = job.construct_path(
            workflow.working_directory, "fine_tune_feature_segments", "scp", dictionary_id
        )
        utt2spk_path = job.construct_path(
            workflow.working_directory, "fine_tune_utt2spk", "scp", dictionary_id
        )
        text_path = job.construct_path(
            workflow.working_directory, "fine_tune_text", "scp", dictionary_id
        )
        columns = [
            PhoneInterval.utterance_id,
            Phone.kaldi_label,
            PhoneInterval.id,
            PhoneInterval.begin,
            PhoneInterval.end,
            SoundFile.sox_string,
            SoundFile.sound_file_path,
            SoundFile.sample_rate,
            Utterance.channel,
            Utterance.speaker_id,
            Utterance.file_id,
        ]
        utterance_ends = {
            k: v
            for k, v in session.query(Utterance.id, Utterance.end).filter(
                Utterance.job_id == self.job_name
            )
        }
        bn = DictBundle("interval_data", *columns)
        interval_query = (
            session.query(bn)
            .join(PhoneInterval.phone)
            .join(PhoneInterval.utterance)
            .join(Utterance.file)
            .join(File.sound_file)
            .filter(Utterance.job_id == self.job_name)
            .filter(PhoneInterval.workflow_id == workflow.id)
            .order_by(PhoneInterval.utterance_id, PhoneInterval.begin)
        )
        wav_data = {}
        utt2spk_data = {}
        segment_data = {}
        text_data = {}
        prev_label = None
        current_id = None
        for row in interval_query:
            data = row.interval_data
            if current_id is None:
                current_id = data["utterance_id"]
            label = data["kaldi_label"]
            if current_id != data["utterance_id"] or prev_label is None:
                # First interval of an utterance: there is no preceding phone, so
                # there is no boundary to fine tune; just remember it
                self.utterance_initial_intervals[data["utterance_id"]] = {
                    "id": data["id"],
                    "begin": data["begin"],
                    "end": data["end"],
                }
                prev_label = label
                current_id = data["utterance_id"]
                continue
            boundary_id = f"{data['utterance_id']}-{data['id']}"
            utt2spk_data[boundary_id] = data["speaker_id"]
            sox_string = data["sox_string"]
            if not sox_string:
                sox_string = f'sox "{data["sound_file_path"]}" -t wav -b 16 -r 16000 - |'
            wav_data[str(data["file_id"])] = sox_string
            interval_begin = data["begin"]
            utterance_end = utterance_ends[data["utterance_id"]]
            self.original_intervals[data["id"]] = {
                "begin": data["begin"],
                "end": data["end"],
                "utterance_id": data["utterance_id"],
            }
            # Alignment window and (wider) feature window, clamped to the utterance
            segment_begin = round(interval_begin - self.padding, 4)
            feature_segment_begin = round(
                interval_begin - (self.padding * self.feature_padding_factor), 4
            )
            if segment_begin < 0:
                segment_begin = 0
            if feature_segment_begin < 0:
                feature_segment_begin = 0
            begin_offset = round(segment_begin - feature_segment_begin, 4)
            segment_end = round(interval_begin + self.padding, 4)
            feature_segment_end = round(
                interval_begin + (self.padding * self.feature_padding_factor), 4
            )
            if segment_end > utterance_end:
                segment_end = utterance_end
            if feature_segment_end > utterance_end:
                feature_segment_end = utterance_end
            # Offsets of the alignment window are relative to the feature window
            end_offset = round(segment_end - feature_segment_begin, 4)
            self.segment_begins[data["id"]] = segment_begin
            self.segment_ends[data["id"]] = data["end"]
            segment_data[boundary_id] = (
                map(
                    str,
                    [
                        data["file_id"],
                        f"{feature_segment_begin:.4f}",
                        f"{feature_segment_end:.4f}",
                        data["channel"],
                    ],
                ),
                map(str, [boundary_id, f"{begin_offset:.4f}", f"{end_offset:.4f}"]),
            )
            # Two-symbol "transcript": the group of the previous and the current phone
            text_data[
                boundary_id
            ] = f"{self.phone_to_group_mapping[prev_label]} {self.phone_to_group_mapping[label]}"
            prev_label = label
        with mfa_open(utt2spk_path, "w") as f:
            for k, v in sorted(utt2spk_data.items()):
                f.write(f"{k} {v}\n")
        with mfa_open(wav_path, "w") as f:
            for k, v in sorted(wav_data.items()):
                f.write(f"{k} {v}\n")
        with mfa_open(segment_path, "w") as f, mfa_open(feature_segment_path, "w") as feature_f:
            for k, v in sorted(segment_data.items()):
                f.write(f"{k} {' '.join(v[0])}\n")
                feature_f.write(f"{k} {' '.join(v[1])}\n")
        with mfa_open(text_path, "w") as f:
            for k, v in sorted(text_data.items()):
                f.write(f"{k} {v}\n")

    @staticmethod
    def _resolve_boundaries(interval_mapping):
        """
        Make consecutive intervals contiguous and drop the ones left without duration

        Each interval's end is set to the following interval's begin; intervals whose
        begin is not before their end are removed, and the process repeats until
        nothing is removed and all neighbours are contiguous.

        Parameters
        ----------
        interval_mapping: list[dict]
            Intervals ordered by position, each with ``id``, ``begin`` and ``end``

        Returns
        -------
        tuple[list[dict], list[int]]
            The remaining intervals and the IDs of the removed ones
        """
        deletions = []
        while True:
            for current, following in zip(interval_mapping, interval_mapping[1:]):
                if current["end"] != following["begin"]:
                    current["end"] = following["begin"]
            new_deletions = [x["id"] for x in interval_mapping if x["begin"] >= x["end"]]
            removed = set(new_deletions)
            interval_mapping = [x for x in interval_mapping if x["id"] not in removed]
            deletions.extend(new_deletions)
            if not new_deletions and all(
                current["end"] == following["begin"]
                for current, following in zip(interval_mapping, interval_mapping[1:])
            ):
                break
        return interval_mapping, deletions

    def _finalize_utterance(self, utterance_id, interval_mapping):
        """
        Order an utterance's updated intervals, prepend its first interval and
        resolve the boundaries between them

        Parameters
        ----------
        utterance_id: int
            Utterance the intervals belong to
        interval_mapping: list[dict]
            Updated intervals collected for the utterance

        Returns
        -------
        tuple[list[dict], list[int]]
            Interval updates and IDs of intervals to delete
        """
        interval_mapping = sorted(interval_mapping, key=lambda x: x["id"])
        interval_mapping.insert(0, self.utterance_initial_intervals[utterance_id])
        return self._resolve_boundaries(interval_mapping)

    def _run(
        self,
    ) -> typing.Generator[typing.Tuple[typing.List[dict], typing.List[int]], None, None]:
        """
        Run the function

        Yields
        ------
        tuple[list[dict], list[int]]
            For each utterance, the updated phone intervals and the IDs of the
            phone intervals that ended up with no duration
        """
        with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
            job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            workflow: CorpusWorkflow = (
                session.query(CorpusWorkflow)
                .filter(CorpusWorkflow.current == True)  # noqa
                .first()
            )
            reversed_phone_mapping = {}
            phone_mapping = {}
            phone_query = session.query(Phone.mapping_id, Phone.id, Phone.kaldi_label)
            for m_id, p_id, phone in phone_query:
                reversed_phone_mapping[m_id] = p_id
                phone_mapping[phone] = m_id
            lexicon_path = os.path.join(workflow.working_directory, "phone.fst")
            group_mapping_path = os.path.join(workflow.working_directory, "groups.txt")

            # Single-state FST mapping every phone to its group symbol (phones that
            # are in no group map to themselves); it is passed to
            # compile-train-graphs in the lexicon position
            fst = pynini.Fst()
            initial_state = fst.add_state()
            fst.set_start(initial_state)
            fst.set_final(initial_state, 0)
            processed = set()
            if self.position_dependent_phones:
                self.grouped_phones["silence"] = ["sil", "sil_B", "sil_I", "sil_E", "sil_S"]
                self.grouped_phones["unknown"] = ["spn", "spn_B", "spn_I", "spn_E", "spn_S"]
            else:
                self.grouped_phones["silence"] = ["sil"]
                self.grouped_phones["unknown"] = ["spn"]
            group_set = ["<eps>"] + sorted(self.grouped_phones)
            group_mapping = {k: i for i, k in enumerate(group_set)}
            self.phone_to_group_mapping = {}
            for k, group in self.grouped_phones.items():
                for p in group:
                    self.phone_to_group_mapping[p] = group_mapping[k]
                    fst.add_arc(
                        initial_state,
                        pywrapfst.Arc(phone_mapping[p], group_mapping[k], 0, initial_state),
                    )
                processed.update(group)
            with mfa_open(group_mapping_path, "w") as f:
                for group_name, group_index in group_mapping.items():
                    f.write(f"{group_index} {group_name}\n")
            for phone, i in phone_mapping.items():
                if phone in processed:
                    continue
                fst.add_arc(initial_state, pywrapfst.Arc(i, i, 0, initial_state))
            fst.arcsort("olabel")
            fst.write(lexicon_path)

            min_length = round(self.frame_shift_seconds / 3, 4)
            cmvn_paths = job.per_dictionary_cmvn_scp_paths
            for d_id in job.dictionary_ids:
                cmvn_path = cmvn_paths[d_id]
                wav_path = job.construct_path(
                    workflow.working_directory, "fine_tune_wav", "scp", d_id
                )
                segment_path = job.construct_path(
                    workflow.working_directory, "fine_tune_segments", "scp", d_id
                )
                feature_segment_path = job.construct_path(
                    workflow.working_directory, "fine_tune_feature_segments", "scp", d_id
                )
                utt2spk_path = job.construct_path(
                    workflow.working_directory, "fine_tune_utt2spk", "scp", d_id
                )
                text_path = job.construct_path(
                    workflow.working_directory, "fine_tune_text", "scp", d_id
                )
                pitch_ark_path = job.construct_path(
                    workflow.working_directory, "fine_tune_pitch", "ark", d_id
                )
                mfcc_ark_path = job.construct_path(
                    workflow.working_directory, "fine_tune_mfcc", "ark", d_id
                )
                feats_ark_path = job.construct_path(
                    workflow.working_directory, "fine_tune_feats", "ark", d_id
                )
                fmllr_path = job.construct_path(workflow.working_directory, "trans", "ark", d_id)
                self.setup_files(session, job, workflow, d_id)
                fst_ark_path = job.construct_path(
                    workflow.working_directory, "fine_tune_fsts", "ark", d_id
                )
                proc = subprocess.Popen(
                    [
                        thirdparty_binary("compile-train-graphs"),
                        f"--read-disambig-syms={self.disambiguation_symbols_int_path}",
                        self.tree_path,
                        self.model_path,
                        lexicon_path,
                        f"ark,s,cs:{text_path}",
                        f"ark:{fst_ark_path}",
                    ],
                    stderr=log_file,
                    env=os.environ,
                )
                proc.communicate()

                seg_proc = subprocess.Popen(
                    [
                        thirdparty_binary("extract-segments"),
                        f"--min-segment-length={min_length}",
                        f"scp:{wav_path}",
                        segment_path,
                        "ark:-",
                    ],
                    stdout=subprocess.PIPE,
                    stderr=log_file,
                    env=os.environ,
                )
                mfcc_proc = compute_mfcc_process(
                    log_file, wav_path, subprocess.PIPE, self.mfcc_options
                )
                cmvn_proc = subprocess.Popen(
                    [
                        # Resolved like every other Kaldi binary in this module
                        thirdparty_binary("apply-cmvn"),
                        f"--utt2spk=ark:{utt2spk_path}",
                        f"scp:{cmvn_path}",
                        "ark:-",
                        f"ark:{mfcc_ark_path}",
                    ],
                    env=os.environ,
                    stdin=mfcc_proc.stdout,
                    stderr=log_file,
                )
                use_pitch = self.pitch_options["use-pitch"] or self.pitch_options["use-voicing"]
                if use_pitch:
                    pitch_proc = compute_pitch_process(
                        log_file, wav_path, subprocess.PIPE, self.pitch_options
                    )
                    pitch_copy_proc = subprocess.Popen(
                        [
                            thirdparty_binary("copy-feats"),
                            "--compress=true",
                            "ark:-",
                            f"ark:{pitch_ark_path}",
                        ],
                        stdin=pitch_proc.stdout,
                        stderr=log_file,
                        env=os.environ,
                    )
                # Feed the extracted audio segments to both feature extractors
                for line in seg_proc.stdout:
                    mfcc_proc.stdin.write(line)
                    mfcc_proc.stdin.flush()
                    if use_pitch:
                        pitch_proc.stdin.write(line)
                        pitch_proc.stdin.flush()
                mfcc_proc.stdin.close()
                if use_pitch:
                    pitch_proc.stdin.close()
                cmvn_proc.wait()
                if use_pitch:
                    pitch_copy_proc.wait()
                if use_pitch:
                    paste_proc = subprocess.Popen(
                        [
                            thirdparty_binary("paste-feats"),
                            "--length-tolerance=2",
                            f"ark:{mfcc_ark_path}",
                            f"ark:{pitch_ark_path}",
                            f"ark:{feats_ark_path}",
                        ],
                        stderr=log_file,
                        env=os.environ,
                    )
                    paste_proc.wait()
                else:
                    feats_ark_path = mfcc_ark_path

                extract_proc = subprocess.Popen(
                    [
                        thirdparty_binary("extract-feature-segments"),
                        f"--min-segment-length={min_length}",
                        f"--frame-shift={self.new_frame_shift}",
                        f'--snip-edges={self.mfcc_options["snip-edges"]}',
                        f"ark,s,cs:{feats_ark_path}",
                        feature_segment_path,
                        "ark:-",
                    ],
                    stderr=log_file,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                )
                trans_proc = compute_transform_process(
                    log_file,
                    extract_proc,
                    workflow.lda_mat_path,
                    self.lda_options,
                    fmllr_path=fmllr_path,
                    utt2spk_path=utt2spk_path,
                )
                align_proc = subprocess.Popen(
                    [
                        thirdparty_binary("gmm-align-compiled"),
                        f"--transition-scale={self.align_options['transition_scale']}",
                        f"--acoustic-scale={self.align_options['acoustic_scale']}",
                        f"--self-loop-scale={self.align_options['self_loop_scale']}",
                        f"--beam={self.align_options['beam']}",
                        f"--retry-beam={self.align_options['retry_beam']}",
                        "--careful=false",
                        self.model_path,
                        f"ark,s,cs:{fst_ark_path}",
                        "ark,s,cs:-",
                        "ark:-",
                    ],
                    stdout=subprocess.PIPE,
                    stderr=log_file,
                    stdin=trans_proc.stdout,
                    env=os.environ,
                )
                ctm_proc = subprocess.Popen(
                    [
                        thirdparty_binary("ali-to-phones"),
                        "--ctm-output",
                        f"--frame-shift={self.new_frame_shift_seconds}",
                        self.model_path,
                        "ark,s,cs:-",
                        "-",
                    ],
                    stderr=log_file,
                    stdin=align_proc.stdout,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                    encoding="utf8",
                )

                interval_mapping = []
                current_utterance = None
                for boundary_id, ctm_intervals in parse_ctm_output(
                    ctm_proc, reversed_phone_mapping, raw_id=True
                ):
                    utterance_id, interval_id = boundary_id.split("-")
                    interval_id = int(interval_id)
                    utterance_id = int(utterance_id)
                    if current_utterance is None:
                        current_utterance = utterance_id
                    if current_utterance != utterance_id:
                        # A new utterance started: flush the previous one
                        yield self._finalize_utterance(current_utterance, interval_mapping)
                        interval_mapping = []
                        current_utterance = utterance_id
                    # The second CTM entry is the phone after the boundary; its begin
                    # is relative to the alignment window
                    interval_mapping.append(
                        {
                            "id": interval_id,
                            "begin": round(
                                ctm_intervals[1].begin + self.segment_begins[interval_id], 4
                            ),
                            "end": self.original_intervals[interval_id]["end"],
                            "label": ctm_intervals[1].label,
                        }
                    )
                if interval_mapping:
                    # The last utterance gets the same treatment as all the others
                    # (ordering by ID and prepending its first interval)
                    yield self._finalize_utterance(current_utterance, interval_mapping)
                self.check_call(ctm_proc)
class PhoneConfidenceFunction(KaldiFunction):
    """
    Multiprocessing function to calculate phone confidence metrics

    See Also
    --------
    :kaldi_src:`gmm-compute-likes`
        Relevant Kaldi binary
    :kaldi_src:`transform-feats`
        Relevant Kaldi binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.alignment.multiprocessing.PhoneConfidenceArguments`
        Arguments for the function
    """

    def __init__(self, args: PhoneConfidenceArguments):
        super().__init__(args)
        self.model_path = args.model_path
        self.phone_pdf_counts_path = args.phone_pdf_counts_path
        self.feature_strings = args.feature_strings

    def _run(self) -> typing.Generator[typing.List[typing.Dict[str, typing.Any]], None, None]:
        """
        Run the function

        Yields
        ------
        list[dict[str, Any]]
            One list per utterance with a ``{"id": ..., "phone_goodness": ...}``
            entry for every scored phone interval. The score of a frame is 0 when
            the best-scoring phone is the aligned phone, otherwise the difference
            between the best phone's score and the aligned phone's score; the
            interval's value is the mean over its frames.
        """
        with Session(self.db_engine()) as session:
            utterances = (
                session.query(Utterance)
                .filter(Utterance.job_id == self.job_name)
                .options(
                    selectinload(Utterance.phone_intervals).joinedload(
                        PhoneInterval.phone, innerjoin=True
                    )
                )
            )
            utterances = {u.id: (u.begin, u.phone_intervals) for u in utterances}
            with mfa_open(self.phone_pdf_counts_path, "r") as f:
                data = json.load(f)
            # phone (position suffix stripped) -> {pdf id: count}
            phone_pdf_mapping = collections.defaultdict(collections.Counter)
            for phone, pdf_counts in data.items():
                phone = split_phone_position(phone)[0]
                for pdf, count in pdf_counts.items():
                    phone_pdf_mapping[phone][int(pdf)] += count
            phones = {p: i for i, p in enumerate(sorted(phone_pdf_mapping.keys()))}
            reversed_phones = {k: v for v, k in phones.items()}
            # Normalise the counts into per-phone weights that sum to 1
            for phone, pdf_counts in phone_pdf_mapping.items():
                phone_total = sum(pdf_counts.values())
                for pdf, count in pdf_counts.items():
                    phone_pdf_mapping[phone][int(pdf)] = count / phone_total
            with mfa_open(self.log_path, "w") as log_file:
                for dict_id in self.feature_strings.keys():
                    feature_string = self.feature_strings[dict_id]
                    output_proc = subprocess.Popen(
                        [
                            thirdparty_binary("gmm-compute-likes"),
                            self.model_path,
                            feature_string,
                            "ark,t:-",
                        ],
                        stderr=log_file,
                        stdout=subprocess.PIPE,
                        env=os.environ,
                    )
                    interval_mappings = []
                    for utterance_id, likelihoods in read_feats(output_proc):
                        # frames x phones: weighted sum of the pdf scores per phone
                        phone_likes = np.zeros((likelihoods.shape[0], len(phones)))
                        for i, p in reversed_phones.items():
                            like = likelihoods[:, list(phone_pdf_mapping[p].keys())]
                            weight = np.array(list(phone_pdf_mapping[p].values()))
                            phone_likes[:, i] = np.dot(like, weight)
                        top_phone_inds = np.argmax(phone_likes, axis=1)
                        utt_begin, intervals = utterances[utterance_id]
                        for pi in intervals:
                            if pi.phone.phone == "sil":
                                continue
                            # Seconds -> frame index, with a hard-coded 10 ms per frame
                            frame_begin = int(((pi.begin - utt_begin) * 1000) / 10)
                            frame_end = int(((pi.end - utt_begin) * 1000) / 10)
                            if frame_begin == frame_end:
                                frame_end += 1
                            frame_end = min(frame_end, top_phone_inds.shape[0])
                            scores = []
                            for i in range(frame_begin, frame_end):
                                top_phone_ind = top_phone_inds[i]
                                alternate_label = reversed_phones[top_phone_ind]
                                alternate_label = split_phone_position(alternate_label)[0]
                                if alternate_label == pi.phone.phone:
                                    scores.append(0)
                                else:
                                    actual_score = phone_likes[i, phones[pi.phone.phone]]
                                    scores.append(phone_likes[i, top_phone_ind] - actual_score)
                            if not scores:
                                # The interval starts at or after the last likelihood
                                # frame, so there is nothing to average;
                                # statistics.mean would raise StatisticsError
                                continue
                            average_score = statistics.mean(scores)
                            interval_mappings.append({"id": pi.id, "phone_goodness": average_score})
                        yield interval_mappings
                        interval_mappings = []
                    self.check_call(output_proc)
class GeneratePronunciationsFunction(KaldiFunction):
    """
    Multiprocessing function for generating pronunciations

    See Also
    --------
    :meth:`.DictionaryTrainer.export_lexicons`
        Main function that calls this function in parallel
    :meth:`.CorpusAligner.generate_pronunciations_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`linear-to-nbest`
        Kaldi binary this uses

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.alignment.multiprocessing.GeneratePronunciationsArguments`
        Arguments for the function
    """

    def __init__(self, args: GeneratePronunciationsArguments):
        super().__init__(args)
        self.model_path = args.model_path
        self.for_g2p = args.for_g2p
        self.reversed_phone_mapping = {}
        self.silence_words = set()

    def _process_pronunciations(
        self, word_pronunciations: typing.List[typing.Tuple[str, str]]
    ) -> PronunciationProbabilityCounter:
        """
        Process an utterance's pronunciations and extract relevant count information

        Parameters
        ----------
        word_pronunciations: list[tuple[str, str]]
            List of (word, pronunciation) tuples for the utterance, in order

        Returns
        -------
        :class:`~montreal_forced_aligner.data.PronunciationProbabilityCounter`
            Counts of pronunciations and of the silence/non-silence contexts
            before and after each (word, pronunciation) pair
        """
        counter = PronunciationProbabilityCounter()
        # Pad with utterance boundary markers so every word has both neighbours
        word_pronunciations = [("<s>", "")] + word_pronunciations + [("</s>", "")]
        last_index = len(word_pronunciations) - 1
        for i, w_p in enumerate(word_pronunciations):
            if i != 0:
                word = word_pronunciations[i - 1][0]
                if word in self.silence_words:
                    counter.silence_before_counts[w_p] += 1
                else:
                    counter.non_silence_before_counts[w_p] += 1
            silence_check = w_p[0] in self.silence_words
            if not silence_check:
                counter.word_pronunciation_counts[w_p[0]][w_p[1]] += 1
                if i != last_index:
                    word = word_pronunciations[i + 1][0]
                    if word in self.silence_words:
                        counter.silence_following_counts[w_p] += 1
                        if i != last_index - 1:
                            # Skip over the silence to the next real entry
                            next_w_p = word_pronunciations[i + 2]
                            counter.ngram_counts[w_p, next_w_p]["silence"] += 1
                    else:
                        next_w_p = word_pronunciations[i + 1]
                        counter.non_silence_following_counts[w_p] += 1
                        counter.ngram_counts[w_p, next_w_p]["non_silence"] += 1
        return counter

    def _run(
        self,
    ) -> typing.Generator[
        typing.Union[
            typing.Tuple[int, int, str], typing.Tuple[int, PronunciationProbabilityCounter]
        ],
        None,
        None,
    ]:
        """
        Run the function

        Yields
        ------
        tuple[int, int, str] or tuple[int, PronunciationProbabilityCounter]
            When ``for_g2p`` is set: dictionary ID, utterance ID and the phone
            string with ``#1``/``#2`` word boundary markers. Otherwise: dictionary
            ID and the pronunciation counts for one utterance.
        """
        self.phone_symbol_table = None
        with mfa_open(self.log_path, "w") as log_file, Session(self.db_engine()) as session:
            job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            workflow: CorpusWorkflow = (
                session.query(CorpusWorkflow)
                .filter(CorpusWorkflow.current == True)  # noqa
                .first()
            )
            phones = session.query(Phone.kaldi_label, Phone.mapping_id)
            for phone, mapping_id in phones:
                self.reversed_phone_mapping[mapping_id] = phone
            for d in job.dictionaries:
                utts = (
                    session.query(Utterance.id, Utterance.normalized_text)
                    .join(Utterance.speaker)
                    .filter(Utterance.job_id == self.job_name)
                    .filter(Speaker.dictionary_id == d.id)
                )
                self.utterance_texts = {}
                for u_id, text in utts:
                    self.utterance_texts[u_id] = text
                if self.phone_symbol_table is None:
                    self.phone_symbol_table = pywrapfst.SymbolTable.read_text(
                        d.phone_symbol_table_path
                    )
                self.word_symbol_table = pywrapfst.SymbolTable.read_text(d.words_symbol_path)
                self.align_lexicon_fst = pynini.Fst.read(d.align_lexicon_path)
                self.clitic_marker = d.clitic_marker
                self.silence_words.add(d.silence_word)
                self.oov_word = d.oov_word
                self.optional_silence_phone = d.optional_silence_phone
                self.oov_phone = d.oov_phone
                self.position_dependent_phones = d.position_dependent_phones
                silence_words = (
                    session.query(Word.word)
                    .filter(Word.dictionary_id == d.id)
                    .filter(Word.word_type == WordType.silence)
                )
                self.silence_words.update(x for x, in silence_words)
                ali_path = job.construct_path(workflow.working_directory, "ali", "ark", d.id)
                if not os.path.exists(ali_path):
                    # No alignments were produced for this dictionary
                    continue
                ctm_proc = subprocess.Popen(
                    [
                        thirdparty_binary("ali-to-phones"),
                        "--ctm-output",
                        self.model_path,
                        f"ark,s,cs:{ali_path}",
                        "-",
                    ],
                    stderr=log_file,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                    encoding="utf8",
                )
                for utterance, intervals in parse_ctm_output(
                    ctm_proc, self.reversed_phone_mapping
                ):
                    word_pronunciations = phones_to_prons(
                        self.utterance_texts[utterance],
                        intervals,
                        self.align_lexicon_fst,
                        d.word_pronunciations,
                        self.word_symbol_table,
                        self.phone_symbol_table,
                        self.optional_silence_phone,
                        oov_phone=self.oov_phone,
                        oov_word=self.oov_word,
                        silence_words=self.silence_words,
                        position_dependent_phones=self.position_dependent_phones,
                    )
                    word_pronunciations = [(x[0], " ".join(x[1])) for x in word_pronunciations]
                    # Anything pronounced as the OOV phone is counted as the OOV word
                    word_pronunciations = [
                        x if x[1] != self.oov_phone else (self.oov_word, self.oov_phone)
                        for x in word_pronunciations
                    ]
                    if self.for_g2p:
                        # Separate name so the phone query above is not shadowed
                        phone_sequence = []
                        for i, x in enumerate(word_pronunciations):
                            # NOTE(review): the marker is presumably None or empty for
                            # dictionaries without clitics (confirm against the
                            # Dictionary model). str.startswith(None) raises TypeError
                            # and "" matches every word, so those values are treated
                            # as "no clitics".
                            attaches_to_previous = (
                                i > 0
                                and bool(self.clitic_marker)
                                and (
                                    x[0].startswith(self.clitic_marker)
                                    or word_pronunciations[i - 1][0].endswith(self.clitic_marker)
                                )
                            )
                            if attaches_to_previous:
                                # Drop the "#2" that closed the previous word so the
                                # clitic is kept in the same unit
                                phone_sequence.pop(-1)
                            else:
                                phone_sequence.append("#1")
                            phone_sequence.extend(x[1].split())
                            phone_sequence.append("#2")
                        yield d.id, utterance, " ".join(phone_sequence)
                    else:
                        yield d.id, self._process_pronunciations(word_pronunciations)
                self.check_call(ctm_proc)
def compile_information_func(
    arguments: CompileInformationArguments,
) -> Dict[str, Union[List[str], float, int]]:
    """
    Multiprocessing function for compiling information about alignment

    Scans an alignment log for utterances that failed to decode, the overall
    log-likelihood per frame and the overall average logdet, copying the matching
    lines to this function's own log file.

    See Also
    --------
    :meth:`.AlignMixin.compile_information`
        Main function that calls this function in parallel

    Parameters
    ----------
    arguments: CompileInformationArguments
        Arguments for the function

    Returns
    -------
    dict[str, Union[list[str], float, int]]
        Information about log-likelihood and number of unaligned files
    """
    decode_error_pattern = re.compile(
        r"^WARNING .* Did not successfully decode file (?P<utt>.*?), .*$"
    )
    log_like_pattern = re.compile(
        r"^LOG .* Overall log-likelihood per frame is (?P<log_like>[-0-9.]+) over (?P<frames>\d+) frames.*$"
    )
    average_logdet_pattern = re.compile(
        r"Overall average logdet is (?P<logdet>[-.,\d]+) over (?P<frames>[.\d+e]+) frames"
    )

    data = {"unaligned": [], "log_like": 0, "total_frames": 0}

    # Use the ".fmllr.log" variant when the plain alignment log does not exist
    source_log_path = arguments.align_log_path
    if not os.path.exists(source_log_path):
        source_log_path = source_log_path.with_suffix(".fmllr.log")

    with mfa_open(arguments.log_path, "w") as log_file, mfa_open(source_log_path, "r") as f:
        log_file.write(f"Processing {source_log_path}...\n")
        for line in f:
            failed_decode = decode_error_pattern.match(line)
            if failed_decode:
                utt = failed_decode.group("utt")
                data["unaligned"].append(utt)
                log_file.write(f"Unaligned: {utt}\n")
                continue

            likelihood = log_like_pattern.search(line)
            if likelihood:
                data["log_like"] = float(likelihood.group("log_like"))
                data["total_frames"] = int(likelihood.group("frames"))
                log_file.write(line)

            logdet_match = average_logdet_pattern.search(line)
            if logdet_match:
                logdet = float(logdet_match.group("logdet"))
                logdet_frames = float(logdet_match.group("frames"))
                data["logdet"] = logdet
                data["logdet_frames"] = logdet_frames
                log_file.write(line)
    return data
class AlignmentExtractionFunction(KaldiFunction):
    """
    Multiprocessing function to collect phone alignments from the aligned lattice

    See Also
    --------
    :meth:`.CorpusAligner.collect_alignments`
        Main function that calls this function in parallel
    :meth:`.CorpusAligner.alignment_extraction_arguments`
        Job method for generating arguments for this function
    :kaldi_src:`linear-to-nbest`
        Relevant Kaldi binary
    :kaldi_src:`lattice-determinize-pruned`
        Relevant Kaldi binary
    :kaldi_src:`lattice-align-words`
        Relevant Kaldi binary
    :kaldi_src:`lattice-to-phone-lattice`
        Relevant Kaldi binary
    :kaldi_src:`nbest-to-ctm`
        Relevant Kaldi binary
    :kaldi_steps:`get_train_ctm`
        Reference Kaldi script

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.alignment.multiprocessing.AlignmentExtractionArguments`
        Arguments for the function
    """

    def __init__(self, args: AlignmentExtractionArguments):
        super().__init__(args)
        # Settings copied from the job arguments
        self.model_path = args.model_path
        self.frame_shift = args.frame_shift
        self.confidence = args.confidence
        self.transcription = args.transcription
        self.score_options = args.score_options
        # Lookup tables, empty until the function runs
        self.utterance_begins = {}
        self.utterance_durations = {}
        self.phone_mapping = {}
        self.reversed_phone_mapping = {}
        self.reversed_word_mapping = {}
        self.pronunciation_mapping = {}
        self.silence_words = set()
def cleanup_intervals(
    self,
    utterance_name,
    intervals: List[CtmInterval],
):
    """
    Shift phone intervals to file time and group them into word intervals

    Phone labels are replaced with phone IDs, and the intervals are modified
    in place.

    Parameters
    ----------
    utterance_name
        Key of the utterance in the utterance text/begin/duration lookups
    intervals: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
        Phone intervals to process, with times relative to the utterance

    Returns
    -------
    list[:class:`~montreal_forced_aligner.data.WordCtmInterval`]
        Word intervals
    list[:class:`~montreal_forced_aligner.data.CtmInterval`]
        Phone intervals
    list[int]
        For each grouped phone, the index of its word interval
    """
    word_pronunciations = phones_to_prons(
        self.utterance_texts[utterance_name],
        intervals,
        self.align_lexicon_fst,
        self.word_pronunciations,
        self.word_symbol_table,
        self.phone_symbol_table,
        self.optional_silence_phone,
        self.transcription,
        oov_phone=self.oov_phone,
        oov_word=self.oov_word,
        silence_words=self.silence_words,
        position_dependent_phones=self.position_dependent_phones,
    )
    phone_results = []
    word_results = []
    phone_to_word = []
    offset = self.utterance_begins[utterance_name]
    total_duration = self.utterance_durations[utterance_name]
    # Stretch the final phone to the utterance end when it stops within 0.05 of it
    final_interval = intervals[-1]
    if total_duration - final_interval.end < 0.05:
        final_interval.end = total_duration

    word_begin = None
    word_cursor = 0
    pending_phones = []
    for phone in intervals:
        phone.begin += offset
        phone.end += offset

        if phone.label == self.optional_silence_phone:
            # Silence is emitted as a one-phone word of its own
            phone.label = self.phone_to_phone_id[phone.label]
            phone_results.append(phone)
            silence_key = (self.silence_word, self.optional_silence_phone)
            word_results.append(
                WordCtmInterval(
                    phone.begin,
                    phone.end,
                    self.word_mapping[self.silence_word],
                    self.pronunciation_mapping[silence_key],
                )
            )
            phone_to_word.append(len(word_results) - 1)
            word_begin = None
            pending_phones = []
            word_cursor += 1
            continue

        if word_begin is None:
            word_begin = phone.begin
        pending_phones.append(phone.label)
        try:
            cur_word = word_pronunciations[word_cursor]
        except IndexError:
            # Running out of words stops processing when transcribing and is an
            # error otherwise
            if not self.transcription:
                raise
            break
        expected = " ".join(cur_word[1])
        collected = " ".join(pending_phones)
        if self.position_dependent_phones:
            collected = re.sub(r"_[BIES]\b", "", collected)

        if collected == expected:
            # All phones of the current word have been seen: close the word
            pron_key = (cur_word[0], expected)
            if expected == self.oov_phone and pron_key not in self.pronunciation_mapping:
                pron_id = self.pronunciation_mapping[(self.oov_word, expected)]
            else:
                pron_id = self.pronunciation_mapping.get(pron_key, None)
            word_results.append(
                WordCtmInterval(
                    word_begin,
                    phone.end,
                    self.word_mapping[cur_word[0]],
                    pron_id,
                )
            )
            phone_to_word.extend([len(word_results) - 1] * len(pending_phones))
            word_begin = None
            pending_phones = []
            word_cursor += 1

        phone.label = self.phone_to_phone_id[phone.label]
        phone_results.append(phone)
    return word_results, phone_results, phone_to_word
def cleanup_g2p_intervals(
    self,
    utterance_name,
    intervals: List[CtmInterval],
):
    """
    Shift phone intervals to file time and group them into word intervals, for
    pronunciations obtained with ``use_g2p=True``

    Words and pronunciations that have no database ID are kept: the word itself is
    used in place of the word ID and the pronunciation ID is left as None. Phone
    labels are replaced with phone IDs, and the intervals are modified in place.

    Parameters
    ----------
    utterance_name: str
        Name of the current utterance
    intervals: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
        Phone intervals to process, with times relative to the utterance

    Returns
    -------
    list[:class:`~montreal_forced_aligner.data.WordCtmInterval`]
        Word intervals
    list[:class:`~montreal_forced_aligner.data.CtmInterval`]
        Phone intervals
    list[int]
        For each grouped phone, the index of its word interval
    """
    word_pronunciations = phones_to_prons(
        self.utterance_texts[utterance_name],
        intervals,
        self.align_lexicon_fst,
        self.word_pronunciations,
        self.word_symbol_table,
        self.phone_symbol_table,
        self.optional_silence_phone,
        oov_phone=self.oov_phone,
        oov_word=self.oov_word,
        use_g2p=True,
        silence_words=self.silence_words,
        position_dependent_phones=self.position_dependent_phones,
    )
    phone_results = []
    word_results = []
    phone_to_word = []
    offset = self.utterance_begins[utterance_name]

    word_begin = None
    word_cursor = 0
    pending_phones = []
    for phone in intervals:
        phone.begin += offset
        phone.end += offset

        if phone.label == self.optional_silence_phone:
            # Silence becomes its own word interval without a pronunciation ID and
            # does not advance the word cursor
            phone.label = self.phone_to_phone_id[phone.label]
            phone_results.append(phone)
            word_results.append(
                WordCtmInterval(
                    phone.begin,
                    phone.end,
                    self.word_mapping[self.silence_word],
                    None,
                )
            )
            phone_to_word.append(len(word_results) - 1)
            word_begin = None
            pending_phones = []
            continue

        if word_begin is None:
            word_begin = phone.begin
        pending_phones.append(phone.label)
        cur_word = word_pronunciations[word_cursor]
        expected = " ".join(cur_word[1])
        collected = " ".join(pending_phones)
        if self.position_dependent_phones:
            collected = re.sub(r"_[BIES]\b", "", collected)

        if collected == expected:
            # All phones of the current word have been seen: close the word
            pron_key = (cur_word[0], expected)
            if expected == self.oov_phone and pron_key not in self.pronunciation_mapping:
                pron_key = (self.oov_word, expected)
            # Unknown pronunciations get no ID; unknown words keep their text
            pron_id = self.pronunciation_mapping.get(pron_key)
            word_id = self.word_mapping.get(cur_word[0], cur_word[0])
            word_results.append(
                WordCtmInterval(
                    word_begin,
                    phone.end,
                    word_id,
                    pron_id,
                )
            )
            phone_to_word.extend([len(word_results) - 1] * len(pending_phones))
            word_begin = None
            pending_phones = []
            word_cursor += 1

        phone.label = self.phone_to_phone_id[phone.label]
        phone_results.append(phone)
    return word_results, phone_results, phone_to_word
# NOTE(review): the return annotation below lists a 3-tuple and gives Generator a
# single type argument, but the function yields 4-tuples
# (utterance, word_intervals, phone_intervals, phone_word_mapping) — confirm and update.
def _run(self) -> typing.Generator[typing.Tuple[int, List[CtmInterval], List[CtmInterval]]]:
    """
    Extract word and phone intervals for every utterance of this job

    For each dictionary attached to the job, per-dictionary lookup tables
    are loaded onto ``self`` (utterance begins/texts, symbol tables, the
    alignment lexicon FST, word and pronunciation ID mappings), then phone
    CTM output is generated with Kaldi binaries and converted into word and
    phone intervals.

    Yields
    ------
    tuple
        ``(utterance, word_intervals, phone_intervals, phone_word_mapping)``
        where ``utterance`` is the utterance key produced by
        ``parse_ctm_output`` and the remaining items are the values returned
        by ``cleanup_intervals`` / ``cleanup_g2p_intervals``
    """
    align_lexicon_paths = {}
    self.phone_symbol_table = None
    with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file:
        job: Job = (
            session.query(Job)
            .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
            .filter(Job.id == self.job_name)
            .first()
        )
        workflow: CorpusWorkflow = (
            session.query(CorpusWorkflow)
            .filter(CorpusWorkflow.current == True)  # noqa
            .first()
        )
        # Phone lookups: mapping ID -> label, label -> database ID, label -> mapping ID
        self.phone_to_phone_id = {}
        ds = session.query(Phone.kaldi_label, Phone.id, Phone.mapping_id).all()
        for phone, p_id, mapping_id in ds:
            self.reversed_phone_mapping[mapping_id] = phone
            self.phone_to_phone_id[phone] = p_id
            self.phone_mapping[phone] = mapping_id
        for d in job.dictionaries:
            # G2P dictionaries align against the character-level text
            columns = [Utterance.id, Utterance.begin, Utterance.duration]
            if d.use_g2p:
                columns.append(Utterance.normalized_character_text)
            else:
                columns.append(Utterance.normalized_text)
            utts = (
                session.query(*columns)
                .join(Utterance.speaker)
                .filter(Utterance.job_id == self.job_name)
                .filter(Speaker.dictionary_id == d.id)
            )
            # Begins and texts are rebuilt for each dictionary; durations are
            # added to the existing ``self.utterance_durations`` mapping
            self.utterance_begins = {}
            self.utterance_texts = {}
            for u_id, begin, duration, text in utts:
                self.utterance_begins[u_id] = begin
                self.utterance_durations[u_id] = duration
                self.utterance_texts[u_id] = text
            # The phone symbol table is only read once, from the first dictionary
            if self.phone_symbol_table is None:
                self.phone_symbol_table = pywrapfst.SymbolTable.read_text(
                    d.phone_symbol_table_path
                )
            self.align_lexicon_fst = pynini.Fst.read(d.align_lexicon_path)
            if d.use_g2p:
                # G2P dictionaries use the grapheme symbol table and the inverted lexicon FST
                self.word_symbol_table = pywrapfst.SymbolTable.read_text(
                    d.grapheme_symbol_table_path
                )
                self.align_lexicon_fst.invert()
            else:
                self.word_symbol_table = pywrapfst.SymbolTable.read_text(d.words_symbol_path)
            self.clitic_marker = d.clitic_marker
            self.silence_word = d.silence_word
            self.word_pronunciations = d.word_pronunciations
            self.oov_word = d.oov_word
            # The OOV phone label is hard-coded rather than read from the dictionary
            self.oov_phone = "spn"
            self.position_dependent_phones = d.position_dependent_phones
            self.optional_silence_phone = d.optional_silence_phone
            # The lattice-based branches below are given the ``align_lexicon_int_path`` lexicon
            if self.transcription or self.confidence:
                align_lexicon_paths[d.id] = d.align_lexicon_int_path
            else:
                align_lexicon_paths[d.id] = d.align_lexicon_path
            silence_words = (
                session.query(Word.id)
                .filter(Word.dictionary_id == d.id)
                .filter(Word.word_type == WordType.silence)
            )
            self.silence_words.update(x for x, in silence_words)
            words = session.query(Word.word, Word.id, Word.mapping_id).filter(
                Word.dictionary_id == d.id
            )
            # Word lookups: word -> database ID, mapping ID -> word
            self.word_mapping = {}
            self.reversed_word_mapping = {}
            for w, w_id, m_id in words:
                self.word_mapping[w] = w_id
                self.reversed_word_mapping[m_id] = w
            # (word, pronunciation string) -> pronunciation database ID
            self.pronunciation_mapping = {}
            pronunciations = (
                session.query(Word.word, Pronunciation.pronunciation, Pronunciation.id)
                .join(Pronunciation.word)
                .filter(Word.dictionary_id == d.id)
            )
            for w, pron, p_id in pronunciations:
                self.pronunciation_mapping[(w, pron)] = p_id
            # Use the "lat.carpa.rescored" lattice when it exists, otherwise the plain "lat" one
            lat_path = job.construct_path(
                workflow.working_directory, "lat.carpa.rescored", "ark", d.id
            )
            if not os.path.exists(lat_path):
                lat_path = job.construct_path(workflow.working_directory, "lat", "ark", d.id)
            ali_path = job.construct_path(workflow.working_directory, "ali", "ark", d.id)
            like_path = job.construct_path(workflow.working_directory, "like", "ark", d.id)
            if self.transcription:
                # Replace the database texts with the text read from the best path
                # of the word-aligned lattice
                self.utterance_texts = {}
                lat_align_proc = subprocess.Popen(
                    [
                        thirdparty_binary("lattice-align-words-lexicon"),
                        align_lexicon_paths[d.id],
                        self.model_path,
                        f"ark,s,cs:{lat_path}",
                        "ark:-",
                    ],
                    stderr=log_file,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                )
                # NOTE(review): the last argument presumably makes lattice-best-path
                # write alignments to ``ali_path``, which the ali-to-phones branch
                # below reads — confirm against the Kaldi usage
                one_best_proc = subprocess.Popen(
                    [
                        thirdparty_binary("lattice-best-path"),
                        f"--acoustic-scale={self.score_options['acoustic_scale']}",
                        "ark,s,cs:-",
                        "ark,t:-",
                        f"ark:{ali_path}",
                    ],
                    stderr=log_file,
                    stdin=lat_align_proc.stdout,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                )
                for line in one_best_proc.stdout:
                    # First field is the utterance key (integer ID after the first "-"),
                    # the remaining fields are word mapping IDs
                    line = line.strip().decode("utf8").split()
                    utt_id = int(line.pop(0).split("-")[1])
                    text = " ".join([self.reversed_word_mapping[int(x)] for x in line])
                    self.utterance_texts[utt_id] = text
            if self.confidence and os.path.exists(lat_path):
                # Pipeline: word-aligned lattice -> phone lattice -> CTM from lattice-to-ctm-conf
                lat_align_proc = subprocess.Popen(
                    [
                        thirdparty_binary("lattice-align-words-lexicon"),
                        align_lexicon_paths[d.id],
                        self.model_path,
                        f"ark,s,cs:{lat_path}",
                        "ark:-",
                    ],
                    stderr=log_file,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                )
                phone_lat_proc = subprocess.Popen(
                    [
                        thirdparty_binary("lattice-to-phone-lattice"),
                        "--replace-words=true",
                        self.model_path,
                        "ark,s,cs:-",
                        "ark:-",
                    ],
                    stderr=log_file,
                    stdin=lat_align_proc.stdout,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                )
                ctm_proc = subprocess.Popen(
                    [
                        thirdparty_binary("lattice-to-ctm-conf"),
                        f"--acoustic-scale={self.score_options['acoustic_scale']}",
                        "ark,s,cs:-",
                        "-",
                    ],
                    stderr=log_file,
                    stdin=phone_lat_proc.stdout,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                    encoding="utf8",
                )
                for utterance, intervals in parse_ctm_output(
                    ctm_proc, self.reversed_phone_mapping
                ):
                    try:
                        (
                            word_intervals,
                            phone_intervals,
                            phone_word_mapping,
                        ) = self.cleanup_intervals(utterance, intervals)
                    except pywrapfst.FstOpError:
                        # Log the text and phone sequence that failed, then skip the utterance
                        log_file.write(f"Error for {utterance}\n")
                        log_file.write(f"{self.utterance_texts[utterance]}\n")
                        log_file.write(f"{' '.join(x.label for x in intervals)}\n")
                        log_file.flush()
                        continue
                    yield utterance, word_intervals, phone_intervals, phone_word_mapping
                self.check_call(ctm_proc)
            else:
                # Phone CTM output generated from the alignment archive
                ctm_proc = subprocess.Popen(
                    [
                        thirdparty_binary("ali-to-phones"),
                        "--ctm-output",
                        f"--frame-shift={self.frame_shift}",
                        self.model_path,
                        f"ark,s,cs:{ali_path}",
                        "-",
                    ],
                    stderr=log_file,
                    stdout=subprocess.PIPE,
                    env=os.environ,
                    encoding="utf8",
                )
                if like_path.exists():
                    # Vectors from the "like" archive, read as text through copy-vector
                    like_proc = subprocess.Popen(
                        [
                            thirdparty_binary("copy-vector"),
                            f"ark,s,cs:{like_path}",
                            "ark,t:-",
                        ],
                        stderr=log_file,
                        stdout=subprocess.PIPE,
                        env=os.environ,
                    )
                    like_gen = read_feats(like_proc)
                for utterance, intervals in parse_ctm_output(
                    ctm_proc, self.reversed_phone_mapping
                ):
                    if like_path.exists():
                        # NOTE(review): ``utt`` is not compared with ``utterance``; this
                        # assumes both streams come in the same order — confirm
                        utt, loglikes = next(like_gen)
                        for interval in intervals:
                            begin_frame = int(round(interval.begin / self.frame_shift))
                            end_frame = int(round(interval.end / self.frame_shift))
                            # Make sure the slice covers at least one frame
                            if begin_frame == end_frame:
                                end_frame += 1
                            # Confidence is the mean of the values over the interval's frames
                            interval.confidence = round(
                                float(np.mean(loglikes[begin_frame:end_frame])), 6
                            )
                    if not d.use_g2p:
                        (
                            word_intervals,
                            phone_intervals,
                            phone_word_mapping,
                        ) = self.cleanup_intervals(utterance, intervals)
                    else:
                        try:
                            (
                                word_intervals,
                                phone_intervals,
                                phone_word_mapping,
                            ) = self.cleanup_g2p_intervals(utterance, intervals)
                        except pywrapfst.FstOpError:
                            # NOTE(review): unlike the confidence branch, the failure is
                            # not logged and the utterance is silently skipped
                            continue
                    yield utterance, word_intervals, phone_intervals, phone_word_mapping
                self.check_call(ctm_proc)
class ExportTextGridProcessWorker(mp.Process):
    """
    Multiprocessing worker for exporting TextGrids

    See Also
    --------
    :meth:`.CorpusAligner.collect_alignments`
        Main function that runs this worker in parallel

    Parameters
    ----------
    db_string: str
        Connection string passed to :func:`sqlalchemy.create_engine`
    for_write_queue: :class:`~multiprocessing.Queue`
        Input queue of files to export; each item is a
        ``(file_id, name, relative_path, duration, text_file_path)`` tuple
    return_queue: :class:`~multiprocessing.Queue`
        Output queue; receives ``1`` for every exported file or an
        :class:`~montreal_forced_aligner.exceptions.AlignmentExportError` on failure
    stopped: :class:`~montreal_forced_aligner.utils.Stopped`
        Stop check for processing
    finished_adding: :class:`~montreal_forced_aligner.utils.Stopped`
        Input signal that all jobs have been added and no more new ones will come in
    arguments: :class:`~montreal_forced_aligner.alignment.multiprocessing.ExportTextGridArguments`
        Arguments to pass to the TextGrid export function
    exported_file_count: :class:`~montreal_forced_aligner.utils.Counter`
        Counter for exported files

    Attributes
    ----------
    finished_processing: :class:`~montreal_forced_aligner.utils.Stopped`
        Set by the worker once the input queue is empty and ``finished_adding``
        has been signalled
    """

    def __init__(
        self,
        db_string: str,
        for_write_queue: mp.Queue,
        return_queue: mp.Queue,
        stopped: Stopped,
        finished_adding: Stopped,
        arguments: ExportTextGridArguments,
        exported_file_count: Counter,
    ):
        mp.Process.__init__(self)
        self.db_string = db_string
        self.for_write_queue = for_write_queue
        self.return_queue = return_queue
        self.stopped = stopped
        self.finished_adding = finished_adding
        self.finished_processing = Stopped()

        self.output_directory = arguments.output_directory
        self.output_format = arguments.output_format
        self.export_frame_shift = arguments.export_frame_shift
        self.log_path = arguments.log_path
        self.include_original_text = arguments.include_original_text
        self.cleanup_textgrids = arguments.cleanup_textgrids
        self.clitic_marker = arguments.clitic_marker
        self.exported_file_count = exported_file_count

    def run(self) -> None:
        """
        Run the exporter function

        Files are taken from ``for_write_queue`` until the queue is empty and
        ``finished_adding`` has been signalled.  Each successful export puts
        ``1`` on ``return_queue``; a failure puts an ``AlignmentExportError``
        on it and signals ``stopped``, after which the remaining queue items
        are drained without being exported.
        """
        db_engine = sqlalchemy.create_engine(
            self.db_string,
            poolclass=sqlalchemy.NullPool,
            pool_reset_on_return=None,
            logging_name=f"{type(self).__name__}_engine",
            isolation_level="AUTOCOMMIT",
        ).execution_options(logging_token=f"{type(self).__name__}_engine")
        with mfa_open(self.log_path, "w") as log_file, Session(db_engine) as session:
            workflow: CorpusWorkflow = (
                session.query(CorpusWorkflow)
                .filter(CorpusWorkflow.current == True)  # noqa
                .first()
            )
            log_file.write(f"Exporting TextGrids for Workflow ID: {workflow.id}\n")
            log_file.write(f"Output directory: {self.output_directory}\n")
            log_file.write(f"Output format: {self.output_format}\n")
            log_file.write(f"Frame shift: {self.export_frame_shift}\n")
            log_file.write(f"Include original text: {self.include_original_text}\n")
            log_file.write(f"Clean up textgrids: {self.cleanup_textgrids}\n")
            while True:
                try:
                    (
                        file_id,
                        name,
                        relative_path,
                        duration,
                        text_file_path,
                    ) = self.for_write_queue.get(timeout=1)
                except Empty:
                    # An empty queue only ends the worker once the producer is done
                    if self.finished_adding.stop_check():
                        self.finished_processing.stop()
                        break
                    continue

                # After a stop signal, keep consuming items so the queue empties,
                # but do not export them
                if self.stopped.stop_check():
                    continue

                # Reset per file so that a failure inside construct_output_path is
                # neither reported under the previous file's path nor turned into
                # an UnboundLocalError inside the handler below
                output_path = None
                try:
                    output_path = construct_output_path(
                        name,
                        relative_path,
                        self.output_directory,
                        text_file_path,
                        self.output_format,
                    )
                    data = construct_output_tiers(
                        session,
                        file_id,
                        workflow,
                        self.cleanup_textgrids,
                        self.clitic_marker,
                        self.include_original_text,
                    )
                    export_textgrid(
                        data, output_path, duration, self.export_frame_shift, self.output_format
                    )
                    self.return_queue.put(1)
                except Exception as error:
                    # Report the failure to the parent and stop further exports;
                    # fall back to the file name when no output path was built
                    self.return_queue.put(
                        AlignmentExportError(
                            output_path if output_path is not None else name,
                            traceback.format_exception(
                                type(error), error, error.__traceback__
                            ),
                        )
                    )
                    self.stopped.stop()
            log_file.write("Done!\n")