# Source code for montreal_forced_aligner.utils

"""
Utility functions
=================

"""
from __future__ import annotations

import datetime
import logging
import multiprocessing as mp
import os
import re
import shutil
import subprocess
import time
import typing
from pathlib import Path
from queue import Empty
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import sqlalchemy
from sqlalchemy.orm import Session

from montreal_forced_aligner.abc import KaldiFunction
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import CtmInterval, DatasetType, MfaArguments
from montreal_forced_aligner.db import Corpus, Dictionary
from montreal_forced_aligner.exceptions import (
    DictionaryError,
    KaldiProcessingError,
    MultiprocessingError,
    ThirdpartyError,
)
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.textgrid import process_ctm_line

# Public names exported by ``from montreal_forced_aligner.utils import *``
__all__ = [
    "check_third_party",
    "thirdparty_binary",
    "log_kaldi_errors",
    "get_mfa_version",
    "parse_logs",
    "inspect_database",
    "Counter",
    "Stopped",
    "ProcessWorker",
    "ProgressCallback",
    "KaldiProcessWorker",
    "parse_ctm_output",
    "read_feats",
    "run_mp",
    "run_non_mp",
    "run_kaldi_function",
]
# Kaldi executables that check_third_party() runs with ``--help`` to confirm
# that each one can be located and launched
canary_kaldi_bins = [
    "compute-mfcc-feats",
    "compute-and-process-kaldi-pitch-feats",
    "gmm-align-compiled",
    "gmm-est-fmllr",
    "gmm-est-fmllr-gpost",
    "lattice-oracle",
    "gmm-latgen-faster",
    "fstdeterminizestar",
    "fsttablecompose",
    "gmm-rescore-lattice",
]

# Module-level logger (named "mfa"); log_kaldi_errors() writes to it
logger = logging.getLogger("mfa")


def inspect_database(name: str) -> DatasetType:
    """
    Work out which kind of dataset a database holds

    The first ``Corpus`` row and the first ``Dictionary`` row are queried; which of
    them exist, plus the corpus's ``has_sound_files`` flag, decide the result.

    Parameters
    ----------
    name: str
        Name of database

    Returns
    -------
    DatasetType
        Dataset type of the database; ``DatasetType.NONE`` when neither row exists or
        when SQLAlchemy raises an ``OperationalError`` or ``ProgrammingError``
    """
    connection_url = f"postgresql+psycopg2://@/{name}?host={GLOBAL_CONFIG.database_socket}"
    try:
        engine = sqlalchemy.create_engine(
            connection_url,
            poolclass=sqlalchemy.NullPool,
            pool_reset_on_return=None,
            isolation_level="AUTOCOMMIT",
            logging_name="inspect_dataset_engine",
        ).execution_options(logging_token="inspect_dataset_engine")
        with Session(engine) as session:
            corpus = session.query(Corpus).first()
            dictionary = session.query(Dictionary).first()
            # No corpus: the database is either empty or dictionary-only
            if corpus is None:
                if dictionary is None:
                    return DatasetType.NONE
                return DatasetType.DICTIONARY
            # A corpus exists: split on sound files first, then on dictionary presence
            if corpus.has_sound_files:
                if dictionary is None:
                    return DatasetType.ACOUSTIC_CORPUS
                return DatasetType.ACOUSTIC_CORPUS_WITH_DICTIONARY
            if dictionary is None:
                return DatasetType.TEXT_CORPUS
            return DatasetType.TEXT_CORPUS_WITH_DICTIONARY
    except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError):
        return DatasetType.NONE


def get_class_for_dataset_type(dataset_type: DatasetType):
    """
    Look up the MFA class that corresponds to a DatasetType

    The corpus and dictionary classes are imported inside the function body rather
    than at module import time.

    Parameters
    ----------
    dataset_type: DatasetType
        Dataset type for the class

    Returns
    -------
    Union[None, AcousticCorpus, TextCorpus, AcousticCorpusWithPronunciations, DictionaryTextCorpus, MultispeakerDictionary]
        Class to use for the current database file (``None`` for ``DatasetType.NONE``)

    Raises
    ------
    KeyError
        If ``dataset_type`` is not one of the handled members
    """
    from montreal_forced_aligner.corpus.acoustic_corpus import (
        AcousticCorpus,
        AcousticCorpusWithPronunciations,
    )
    from montreal_forced_aligner.corpus.text_corpus import DictionaryTextCorpus, TextCorpus
    from montreal_forced_aligner.dictionary import MultispeakerDictionary

    class_for_type = {
        DatasetType.NONE: None,
        DatasetType.ACOUSTIC_CORPUS: AcousticCorpus,
        DatasetType.TEXT_CORPUS: TextCorpus,
        DatasetType.ACOUSTIC_CORPUS_WITH_DICTIONARY: AcousticCorpusWithPronunciations,
        DatasetType.TEXT_CORPUS_WITH_DICTIONARY: DictionaryTextCorpus,
        DatasetType.DICTIONARY: MultispeakerDictionary,
    }
    selected = class_for_type[dataset_type]
    return selected


def parse_dictionary_file(
    path: Path,
) -> typing.Generator[
    typing.Tuple[
        str,
        typing.Tuple[str, ...],
        typing.Optional[float],
        typing.Optional[float],
        typing.Optional[float],
        typing.Optional[float],
    ],
    None,
    None,
]:
    """
    Parses a lexicon file and yields parsed pronunciation lines

    Each non-blank line is split on whitespace.  The first field is the word; up to
    four following fields that look like decimal numbers are consumed, in order, as
    the pronunciation probability and the three silence-related values; everything
    that remains is the pronunciation.

    Parameters
    ----------
    path: :class:`~pathlib.Path`
        Path to lexicon file

    Yields
    ------
    str
        Orthographic word
    tuple[str, ...]
        Pronunciation
    float or None
        Pronunciation probability
    float or None
        Probability of silence following the pronunciation
    float or None
        Correction factor for silence before the pronunciation
    float or None
        Correction factor for no silence before the pronunciation

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.DictionaryError`
        If a line has no pronunciation, either because it only contains a word or
        because every field after the word was consumed as a numeric value
    """
    prob_pattern = re.compile(r"\b\d+\.\d+\b")
    max_numeric_fields = 4
    with mfa_open(path) as f:
        # ``i`` is zero-based, so the first line of the file is reported as line 0
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            entry = line.split()
            if len(entry) <= 1:
                raise DictionaryError(
                    f'Error parsing line {i} of {path}: "{entry}" did not have a pronunciation'
                )
            word = entry[0]
            position = 1
            numeric_fields = []
            # Bounds-check before every look-ahead so a line made up only of numeric
            # fields cannot index past the end of the entry
            while (
                len(numeric_fields) < max_numeric_fields
                and position < len(entry)
                and prob_pattern.match(entry[position])
            ):
                numeric_fields.append(float(entry[position]))
                position += 1
            if position >= len(entry):
                raise DictionaryError(
                    f'Error parsing line {i} of {path}: "{entry}" did not have a pronunciation'
                )
            # Pad with None so the unpacking below always sees exactly four values
            numeric_fields.extend([None] * (max_numeric_fields - len(numeric_fields)))
            (
                prob,
                silence_after_prob,
                silence_before_correct,
                non_silence_before_correct,
            ) = numeric_fields
            pron = tuple(entry[position:])
            yield word, pron, prob, silence_after_prob, silence_before_correct, non_silence_before_correct


def parse_ctm_output(
    proc: subprocess.Popen, reversed_phone_mapping: Dict[int, Any], raw_id: bool = False
) -> typing.Generator[typing.Tuple[typing.Union[int, str], typing.List[CtmInterval]]]:
    """
    Group the CTM lines written to a process's stdout by utterance

    Consecutive lines that belong to the same utterance are collected into one list;
    the list is yielded as soon as a line for a different utterance is seen, and once
    more at the end of the stream if anything is pending.  Blank lines and lines for
    which ``process_ctm_line`` raises ``ValueError`` are skipped.

    Parameters
    ----------
    proc: :class:`subprocess.Popen`
        Process whose ``stdout`` is iterated line by line
    reversed_phone_mapping: dict[int, Any]
        Mapping from kaldi integer IDs to phones
    raw_id: bool
        Flag for returning the kaldi internal ID of the utterance rather than its integer ID

    Yields
    -------
    int or str
        Utterance ID
    list[:class:`~montreal_forced_aligner.data.CtmInterval`]
        List of CTM intervals for the utterance
    """
    active_utt = None
    collected = []
    for raw_line in proc.stdout:
        text = raw_line.strip()
        if not text:
            continue
        try:
            utt, interval = process_ctm_line(text, reversed_phone_mapping, raw_id=raw_id)
        except ValueError:
            continue
        if active_utt is None:
            # First parsable line: start tracking its utterance
            active_utt = utt
        elif active_utt != utt:
            # Utterance changed: emit what was gathered and start a new group
            yield active_utt, collected
            collected = []
            active_utt = utt
        collected.append(interval)
    if collected:
        yield active_utt, collected


def get_mfa_version() -> str:
    """
    Report the version of MFA

    The version is read from the package's ``_version`` module; when that module
    cannot be imported, the fallback string ``"2.0.0"`` is returned.

    Returns
    -------
    str
        MFA version
    """
    try:
        from ._version import version  # noqa
    except ImportError:
        return "2.0.0"
    return version


def check_third_party():
    """
    Verify that the required third party executables can be found and launched

    ``sox``, ``initdb`` and ``fstcompile`` are looked up on the path, ``fstcompile``
    is run with ``--help``, and then every entry of ``canary_kaldi_bins`` is run with
    ``--help``.  A return code of 1 together with non-empty stderr is treated as a
    failure.

    Raises
    -------
    :class:`~montreal_forced_aligner.exceptions.ThirdpartyError`
        If an executable is missing, cannot be started, or fails its ``--help`` check
    """
    for required in ("sox", "initdb"):
        if shutil.which(required) is None:
            raise ThirdpartyError(required)
    if shutil.which("fstcompile") is None:
        raise ThirdpartyError("fstcompile", open_fst=True)

    fst_check = subprocess.run(["fstcompile", "--help"], capture_output=True, text=True)
    if fst_check.returncode == 1 and fst_check.stderr:
        raise ThirdpartyError("fstcompile", open_fst=True, error_text=fst_check.stderr)

    for kaldi_bin in canary_kaldi_bins:
        try:
            command = [thirdparty_binary(kaldi_bin), "--help"]
            kaldi_check = subprocess.run(command, capture_output=True, text=True)
        except Exception as e:
            raise ThirdpartyError(kaldi_bin, error_text=str(e))
        if kaldi_check.returncode == 1 and kaldi_check.stderr:
            raise ThirdpartyError(kaldi_bin, error_text=kaldi_check.stderr)


def thirdparty_binary(binary_name: str) -> str:
    """
    Resolve an executable name to its full path

    Notes
    -----
    With the move to conda, this function is deprecated as conda will manage the path much better

    Parameters
    ----------
    binary_name: str
        Executable to run

    Returns
    -------
    str
        Full path to the executable, wrapped in double quotes when the path contains a space

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.ThirdpartyError`
        If the executable cannot be found; flagged with ``open_fst=True`` for
        ``fstcompile``, ``fstarcsort`` and ``fstconvert``
    """
    resolved = shutil.which(binary_name)
    if resolved is not None:
        return f'"{resolved}"' if " " in resolved else resolved
    open_fst_tools = ("fstcompile", "fstarcsort", "fstconvert")
    if binary_name in open_fst_tools:
        raise ThirdpartyError(binary_name, open_fst=True)
    raise ThirdpartyError(binary_name)
def log_kaldi_errors(error_logs: List[str]) -> None:
    """
    Copy the contents of failing Kaldi log files into the module logger

    A summary line with the number of files is written first; then, for each file,
    an empty message, the path, and every line of the file (stripped and prefixed
    with a tab) are written at debug level.

    Parameters
    ----------
    error_logs: list[str]
        Kaldi log files with errors
    """
    file_count = len(error_logs)
    logger.debug(f"There were {file_count} kaldi processing files that had errors:")
    for log_path in error_logs:
        logger.debug("")
        logger.debug(log_path)
        with mfa_open(log_path, "r") as log_file:
            for entry in log_file:
                logger.debug("\t" + entry.strip())
def read_feats(
    proc: subprocess.Popen, raw_id=False
) -> typing.Generator[typing.Tuple[typing.Union[str, int], np.ndarray], None, None]:
    """
    Inspired by https://github.com/it-muslim/kaldi-helpers/blob/master/kaldi-helpers/kaldi_io.py#L87

    Reading from stdout, import feats (or feats-like) data as a numpy array

    As feats are generated "on-fly" in kaldi, there is no a feats file
    (except most simple cases like raw mfcc, plp or fbank).  So, that is why
    we take feats as a command rather that a file path.  Can be applied to
    other commands (like gmm-compute-likes) generating an output in same
    format as feats, i.e::

        utterance_id_1  [
          70.31843 -2.872698 -0.06561285 22.71824 -15.57525 ...
          78.39457 -1.907646 -1.593253 23.57921 -14.74229 ...
          ...
          57.27236 -16.17824 -15.33368 -5.945696 0.04276848 ... -0.5812851 ]
        utterance_id_2  [
          64.00951 -8.952017 4.134113 33.16264 11.09073 ...
          ...

    An entry whose opening and closing brackets share one line is yielded as a
    one-dimensional array; an entry spread over several lines is yielded as a
    two-dimensional array with one row per line.

    Parameters
    ----------
    proc : subprocess.Popen
        A process that generates features or feature-like specifications;
        its ``stdout`` must yield ASCII-encoded ``bytes`` lines
    raw_id: bool
        When true, yield the ID exactly as printed; otherwise yield the integer
        that follows the last ``-`` in the printed ID

    Yields
    ------
    int or str
        Utterance ID
    numpy.array
        features
    """

    def _utterance_id(token: str) -> typing.Union[str, int]:
        """Convert the printed ID into the form requested through ``raw_id``."""
        return token if raw_id else int(token.split("-")[-1])

    rows = []  # rows gathered so far for the pending multi-line entry
    current_id = None  # ID of the pending multi-line entry, if any
    for line in proc.stdout:
        line = line.decode("ascii").strip()
        if not line:
            continue
        if "[" in line and "]" in line:
            # A complete vector on one line.  Emit any pending matrix first so its
            # rows are neither discarded nor yielded after this vector.
            if current_id is not None:
                yield current_id, np.array(rows)
                rows = []
                current_id = None
            tokens = line.replace("]", "").replace("[", "").split()
            utt_id = _utterance_id(tokens.pop(0))
            yield utt_id, np.array([float(x) for x in tokens])
            continue
        if "[" in line:
            # Header line of a multi-line matrix: only the ID is read from it
            utt_id = _utterance_id(line.split()[0])
            if current_id is None:
                current_id = utt_id
            if current_id != utt_id:
                yield current_id, np.array(rows)
                rows = []
                current_id = utt_id
            continue
        # Data row; the last row of a matrix carries the closing bracket
        rows.append([float(x) for x in line.replace("]", "").split()])
    if current_id is not None:
        yield current_id, np.array(rows)
def parse_logs(log_directory: Path) -> None:
    """
    Parse the output of a Kaldi run for any errors and raise relevant MFA exceptions

    Only files directly inside ``log_directory`` whose suffix is ``.log`` are read.
    Known third party failures (missing OpenBLAS library, unsupported libc versions,
    unsupported sox format) raise immediately; otherwise every log containing a line
    that starts with ``ERROR`` or ``ASSERTION_FAILED`` is collected and reported
    together.

    Parameters
    ----------
    log_directory: :class:`~pathlib.Path` or str
        Log directory to parse

    Raises
    ------
    ThirdpartyError
        If a log line matches one of the known third party failure messages
    KaldiProcessingError
        If any log files contained error lines
    """
    # Sibling helpers annotate this argument as ``str``, so normalise it here
    log_directory = Path(log_directory)
    error_logs = []
    for log_path in log_directory.iterdir():
        if log_path.is_dir() or log_path.suffix != ".log":
            continue
        with mfa_open(log_path, "r") as f:
            for line in f:
                line = line.strip()
                if "error while loading shared libraries: libopenblas.so.0" in line:
                    raise ThirdpartyError("libopenblas.so.0", open_blas=True)
                for libc_version in ["GLIBC_2.27", "GLIBCXX_3.4.20"]:
                    if libc_version in line:
                        raise ThirdpartyError(libc_version, libc=True)
                if "sox FAIL formats" in line:
                    # Last space-separated token of the message; kept in its own
                    # name so the open file handle ``f`` is not rebound
                    unsupported_format = line.split(" ")[-1]
                    raise ThirdpartyError(unsupported_format, sox=True)
                if line.startswith(("ERROR", "ASSERTION_FAILED")):
                    # One hit is enough to flag this file; move on to the next log
                    error_logs.append(log_path)
                    break
    if error_logs:
        raise KaldiProcessingError(error_logs)
class Counter(object):
    """
    Integer counter backed by multiprocessing primitives, used to track progress

    Attributes
    ----------
    val: :func:`~multiprocessing.Value`
        Integer to increment
    lock: :class:`~multiprocessing.Lock`
        Lock held around every read and update of ``val``
    """

    def __init__(self, init_val: int = 0):
        self.val = mp.Value("i", init_val)
        self.lock = mp.Lock()

    def increment(self, value=1) -> None:
        """Add ``value`` (1 by default) to the counter"""
        with self.lock:
            updated = self.val.value + value
            self.val.value = updated

    def value(self) -> int:
        """Read the counter's current value"""
        with self.lock:
            current = self.val.value
        return current
class ProgressCallback(object):
    """
    Class for sending progress indications back to the main process

    Parameters
    ----------
    callback: Callable, optional
        Called as ``callback(progress, remaining_time)`` after every call to
        :meth:`increment_progress`; ``remaining_time`` is the string form of a
        :class:`datetime.timedelta`
    total_callback: Callable, optional
        Called as ``total_callback(total)`` after every call to :meth:`update_total`

    Notes
    -----
    Both callbacks are invoked while ``lock`` (a non-reentrant
    :class:`multiprocessing.Lock`) is held, so they must not read ``total``,
    ``progress`` or ``progress_percent`` on the same object.
    """

    def __init__(self, callback=None, total_callback=None):
        self._total = 0
        self.callback = callback
        self.total_callback = total_callback
        self._progress = 0
        self.callback_interval = 1
        self.lock = mp.Lock()
        # Set when the total first becomes non-zero, or on the first reported increment
        self.start_time = None

    @property
    def total(self) -> int:
        """Total entries to process"""
        with self.lock:
            return self._total

    @property
    def progress(self) -> int:
        """Current number of entries processed"""
        with self.lock:
            return self._progress

    @property
    def progress_percent(self) -> float:
        """Current progress as a fraction of the total (0.0 while the total is zero)"""
        with self.lock:
            if not self._total:
                return 0.0
            return self._progress / self._total

    def update_total(self, total: int) -> None:
        """
        Update the total for the callback

        Parameters
        ----------
        total: int
            New total
        """
        with self.lock:
            # Timing starts the first time a non-zero total replaces a zero total
            if self._total == 0 and total != 0:
                self.start_time = time.time()
            self._total = total
            if self.total_callback is not None:
                self.total_callback(self._total)

    def set_progress(self, progress: int) -> None:
        """
        Update the number of entries processed for the callback

        Parameters
        ----------
        progress: int
            New progress
        """
        with self.lock:
            self._progress = progress

    def increment_progress(self, increment: int) -> None:
        """
        Increment the number of entries processed for the callback

        When a callback is registered it receives the new progress value and an
        estimate of the remaining time, extrapolated from the average time per
        processed entry so far.

        Parameters
        ----------
        increment: int
            Update the progress by this amount
        """
        with self.lock:
            self._progress += increment
            if self.callback is None:
                return
            current_time = time.time()
            if self.start_time is None:
                # No total was ever set, so start timing from this first report
                self.start_time = current_time
            current_duration = current_time - self.start_time
            # Without any processed entries there is no rate to extrapolate from
            if self._progress > 0:
                time_per_iteration = current_duration / self._progress
            else:
                time_per_iteration = 0.0
            # Progress may overshoot the total; never report a negative remainder
            remaining_iterations = max(self._total - self._progress, 0)
            remaining_time = datetime.timedelta(
                seconds=int(time_per_iteration * remaining_iterations)
            )
            self.callback(self._progress, str(remaining_time))
class Stopped(object):
    """
    Flag backed by multiprocessing primitives that tells processes to stop and exit ASAP

    Attributes
    ----------
    val: :func:`~multiprocessing.Value`
        0 if not stopped, 1 if stopped
    lock: :class:`~multiprocessing.Lock`
        Lock held around every read and update of the shared values
    _source: multiprocessing.Value
        1 if it was a Ctrl+C event that stopped it, 0 otherwise
    """

    def __init__(self, initval: Union[bool, int] = False):
        self.val = mp.Value("i", initval)
        self.lock = mp.Lock()
        self._source = mp.Value("i", 0)

    def reset(self) -> None:
        """Clear the stop flag so that work may continue"""
        with self.lock:
            self.val.value = 0

    def stop(self) -> None:
        """Signal that work should stop asap"""
        with self.lock:
            self.val.value = 1

    def stop_check(self) -> int:
        """Check whether a process should stop"""
        with self.lock:
            flag = self.val.value
        return flag

    def set_sigint_source(self) -> None:
        """Set the source as a ctrl+c"""
        with self.lock:
            self._source.value = 1

    def source(self) -> int:
        """Get the source value"""
        with self.lock:
            origin = self._source.value
        return origin
class ProcessWorker(mp.Process):
    """
    Worker process that applies a function to argument sets pulled from a queue

    Parameters
    ----------
    job_name: int
        Integer number of job
    job_q: :class:`~multiprocessing.Queue`
        Job queue to pull arguments from
    function: Callable
        Multiprocessing function to call on arguments from job_q
    return_q: :class:`~multiprocessing.Queue`
        Queue that receives ``(job_name, result)`` tuples; ``result`` is the raised
        exception when the function fails
    stopped: :class:`~montreal_forced_aligner.utils.Stopped`
        Stop check

    Attributes
    ----------
    finished_processing: :class:`~montreal_forced_aligner.utils.Stopped`
        Set once the job queue has been found empty
    """

    def __init__(
        self,
        job_name: int,
        job_q: mp.Queue,
        function: Callable,
        return_q: mp.Queue,
        stopped: Stopped,
    ):
        super().__init__()
        self.job_name = job_name
        self.function = function
        self.job_q = job_q
        self.return_q = return_q
        self.stopped = stopped
        self.finished_processing = Stopped()

    def run(self) -> None:
        """
        Pull argument sets from the job queue and apply the function to each of them
        """
        while True:
            try:
                arguments = self.job_q.get(timeout=1)
            except Empty:
                # Nothing arrived within the timeout: mark this worker as done and exit
                self.finished_processing.stop()
                return
            try:
                # MfaArguments objects are passed whole; anything else is unpacked
                pass_whole = isinstance(arguments, MfaArguments)
                result = self.function(arguments) if pass_whole else self.function(*arguments)
                self.return_q.put((self.job_name, result))
            except Exception as e:
                self.stopped.stop()
                if isinstance(e, (KaldiProcessingError, MultiprocessingError)):
                    e.job_name = self.job_name
                self.return_q.put((self.job_name, e))
class KaldiProcessWorker(mp.Process):
    """
    Worker process that runs one KaldiFunction and forwards its results

    Parameters
    ----------
    job_name: int
        Integer number of job
    return_q: :class:`~multiprocessing.Queue`
        Queue for returning results
    function: KaldiFunction
        Function object whose ``run()`` results are forwarded to ``return_q``
    stopped: :class:`~montreal_forced_aligner.utils.Stopped`
        Stop check

    Attributes
    ----------
    finished: :class:`~montreal_forced_aligner.utils.Stopped`
        Set when ``run`` exits, whether or not an exception occurred
    """

    def __init__(
        self,
        job_name: int,
        return_q: mp.Queue,
        function: KaldiFunction,
        stopped: Stopped,
    ):
        super().__init__()
        self.job_name = job_name
        self.function = function
        self.return_q = return_q
        self.stopped = stopped
        self.finished = Stopped()

    def run(self) -> None:
        """
        Iterate over the function's results and put each one on the return queue
        """
        # Apply the configured BLAS thread count to each threading variable
        thread_count = f"{GLOBAL_CONFIG.current_profile.blas_num_threads}"
        for variable in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
            os.environ[variable] = thread_count
        try:
            for result in self.function.run():
                self.return_q.put(result)
        except Exception as e:
            self.stopped.stop()
            if isinstance(e, KaldiProcessingError):
                e.job_name = self.job_name
            # The exception itself is sent back so the parent can re-raise it
            self.return_q.put(e)
        finally:
            self.finished.stop()


def run_kaldi_function(function, arguments, progress_callback, stopped: Stopped = None):
    """
    Run a Kaldi function over a list of argument sets and yield its results

    With ``GLOBAL_CONFIG.use_mp`` set, one :class:`KaldiProcessWorker` is started per
    argument set and results are yielded as they arrive on a shared queue; otherwise
    the function objects are run one after another in the current process.

    Parameters
    ----------
    function: Callable
        Called with one argument set; the returned object must provide ``run()``
    arguments: list
        Argument sets, one per job
    progress_callback: Callable
        Called with ``1`` after every yielded result
    stopped: :class:`~montreal_forced_aligner.utils.Stopped`, optional
        Stop check; a fresh one is created when omitted

    Yields
    ------
    Any
        Results produced by the function objects' ``run()`` methods

    Raises
    ------
    Exception
        In multiprocessing mode, the first exception received from a worker,
        raised after all workers have been joined
    """
    if stopped is None:
        stopped = Stopped()

    if not GLOBAL_CONFIG.use_mp:
        for args in arguments:
            job = function(args)
            for result in job.run():
                if stopped.stop_check():
                    break
                yield result
                progress_callback(1)
        return

    errors = {}
    return_queue = mp.Queue(10000)
    workers = []
    for index, args in enumerate(arguments):
        job = function(args)
        worker = KaldiProcessWorker(index, return_queue, job, stopped)
        workers.append(worker)
        worker.start()

    while True:
        try:
            result = return_queue.get(timeout=1)
        except Empty:
            # An idle queue only ends the loop once every worker reports finished
            if all(worker.finished.stop_check() for worker in workers):
                break
            continue
        if isinstance(result, Exception):
            errors[getattr(result, "job_name", 0)] = result
            continue
        if stopped.stop_check():
            # Once stopped, remaining results are read from the queue and discarded
            continue
        yield result
        progress_callback(1)

    for worker in workers:
        worker.join()
    if errors:
        raise next(iter(errors.values()))
def run_non_mp(
    function: Callable,
    argument_list: List[Union[Tuple[Any, ...], MfaArguments]],
    log_directory: str,
    return_info: bool = False,
) -> Optional[Dict[Any, Any]]:
    """
    Sequential counterpart of :func:`run_mp`: the jobs are evaluated one after another
    in the current process

    Parameters
    ----------
    function: Callable
        Multiprocessing function to evaluate
    argument_list: list
        List of arguments to process
    log_directory: str
        Directory that all log information from the processes goes to
    return_info: bool
        Whether the values returned by the function should be collected and returned

    Returns
    -------
    dict, optional
        Mapping from each job's position in ``argument_list`` to the value the function
        returned for it when ``return_info`` is true, otherwise ``None``
    """
    collected = {}
    for position, args in enumerate(argument_list):
        # MfaArguments objects are passed whole; anything else is unpacked
        if isinstance(args, MfaArguments):
            outcome = function(args)
        else:
            outcome = function(*args)
        if return_info:
            collected[position] = outcome
    parse_logs(log_directory)
    return collected if return_info else None
def run_mp(
    function: Callable,
    argument_list: List[Union[Tuple[Any, ...], MfaArguments]],
    log_directory: str,
    return_info: bool = False,
) -> Optional[Dict[int, Any]]:
    """
    Apply a function for each job in parallel

    One :class:`ProcessWorker` is started per entry of ``argument_list``; all workers
    pull from a shared job queue and report on a shared return queue.

    Parameters
    ----------
    function: Callable
        Multiprocessing function to apply
    argument_list: list
        Arguments for each job
    log_directory: str
        Directory that all log information from the processes goes to
    return_info: bool
        Whether the values returned by the function should be collected and returned

    Returns
    -------
    dict, optional
        Mapping from worker number to the value returned by the function when
        ``return_info`` is true, otherwise ``None``

    Raises
    ------
    Exception
        The first exception reported by a worker, raised after all workers have
        been joined
    """
    blas_threads = f"{GLOBAL_CONFIG.current_profile.blas_num_threads}"
    for variable in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
        os.environ[variable] = blas_threads
    stopped = Stopped()
    job_queue = mp.Queue()
    return_queue = mp.Queue()
    error_dict = {}
    info = {}
    for a in argument_list:
        job_queue.put(a)
    procs = []
    for i, _ in enumerate(argument_list):
        p = ProcessWorker(i, job_queue, function, return_queue, stopped)
        procs.append(p)
        p.start()

    while True:
        try:
            job_name, result = return_queue.get(timeout=1)
        except Empty:
            # An idle queue only ends the loop once every worker reports finished
            if all(proc.finished_processing.stop_check() for proc in procs):
                break
            continue
        # Record errors BEFORE the stop check: a worker sets ``stopped`` before it
        # enqueues its exception, so checking ``stopped`` first would discard every
        # worker error.  Any exception is recorded, matching run_kaldi_function.
        if isinstance(result, Exception):
            error_dict[job_name] = result
            continue
        if stopped.stop_check():
            # Once stopped, remaining results are read from the queue and discarded
            continue
        info[job_name] = result

    for p in procs:
        p.join()
    if error_dict:
        for v in error_dict.values():
            raise v

    parse_logs(log_directory)
    if return_info:
        return info
    return None