# Source code for montreal_forced_aligner.helper

"""
Helper functions
================

"""
from __future__ import annotations

import itertools
import json
import logging
import re
import typing
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING

import dataclassy
import numpy
import yaml
from rich.console import Console
from rich.logging import RichHandler
from rich.theme import Theme

if TYPE_CHECKING:
    from montreal_forced_aligner.abc import MetaDict


# Public names exported by ``from montreal_forced_aligner.helper import *``
__all__ = [
    "comma_join",
    "make_safe",
    "make_scp_safe",
    "load_scp",
    "load_scp_safe",
    "score_wer",
    "score_g2p",
    "edit_distance",
    "output_mapping",
    "parse_old_features",
    "make_re_character_set_safe",
    "split_phone_position",
    "configure_logger",
    "mfa_open",
    "load_configuration",
    "format_correction",
    "format_probability",
    "load_evaluation_mapping",
]


# Shared rich console that writes to stderr. The theme overrides the styles rich
# uses for the ``logging.level.*`` names; configure_logger attaches its
# RichHandler to this console.
console = Console(
    theme=Theme(
        {
            "logging.level.debug": "cyan",
            "logging.level.info": "green",
            "logging.level.warning": "yellow",
            "logging.level.error": "red",
        }
    ),
    stderr=True,
)


@contextmanager
def mfa_open(
    path: typing.Union[Path, str],
    mode: str = "r",
    encoding: str = "utf8",
    newline: typing.Optional[str] = "",
):
    """
    Open a file with MFA's default text settings and close it on exit

    Binary modes (any mode containing ``"b"``) are opened without text options.
    Text modes use ``encoding``; ``newline`` is only applied when the mode does
    not contain ``"r"`` (i.e. pure write/append modes), so reads keep Python's
    default newline translation.

    Parameters
    ----------
    path: :class:`~pathlib.Path` or str
        File to open
    mode: str
        Mode passed to :func:`open`
    encoding: str
        Text encoding for text modes
    newline: str, optional
        Newline setting for text write modes

    Yields
    ------
    file object
        The opened file
    """
    open_kwargs = {}
    if "b" not in mode:
        open_kwargs["encoding"] = encoding
        if "r" not in mode:
            open_kwargs["newline"] = newline
    handle = open(path, mode, **open_kwargs)
    try:
        yield handle
    finally:
        handle.close()


def load_configuration(config_path: typing.Union[str, Path]) -> typing.Dict[str, typing.Any]:
    """
    Load a configuration file

    Parameters
    ----------
    config_path: :class:`~pathlib.Path` or str
        Path to a yaml (``.yaml`` / ``.yml``) or json (``.json``) configuration file;
        the extension is matched case-insensitively

    Returns
    -------
    dict[str, Any]
        Configuration dictionary; empty if the file is empty or has an
        unrecognized extension
    """
    config_path = Path(config_path)
    suffix = config_path.suffix.lower()
    data = None
    with mfa_open(config_path, "r") as f:
        if suffix in {".yaml", ".yml"}:
            # NOTE(review): yaml.Loader can construct arbitrary Python objects.
            # Consider yaml.safe_load if configuration files may come from untrusted sources.
            data = yaml.load(f, Loader=yaml.Loader)
        elif suffix == ".json":
            data = json.load(f)
    # Empty documents load as None; normalize every falsy result to an empty dict
    return data or {}


def split_phone_position(phone_label: str) -> typing.Tuple[str, typing.Optional[str]]:
    """
    Splits a phone label into its original phone and its positional label

    The split happens at the last underscore, so phones that themselves contain
    underscores keep them.

    Parameters
    ----------
    phone_label: str
        Phone label, e.g. ``"AA_B"``

    Returns
    -------
    str
        Phone (the whole label if there is no underscore)
    str or None
        Position label, or None if the label has no underscore
    """
    phone, separator, pos = phone_label.rpartition("_")
    if not separator:
        return phone_label, None
    return phone, pos


def parse_old_features(config: MetaDict) -> MetaDict:
    """
    Backwards compatibility function to parse old feature configuration blocks

    The obsolete ``lda`` and ``fmllr`` keys are removed, and ``type`` / ``deltas``
    are renamed to ``feature_type`` / ``uses_deltas``. If the configuration has a
    ``features`` section, only that section is upgraded; otherwise the top level
    is. The dictionary is modified in place.

    Parameters
    ----------
    config: dict[str, Any]
        Configuration parameters

    Returns
    -------
    dict[str, Any]
        The same dictionary, with up to date feature keys
    """
    renamed_keys = {
        "type": "feature_type",
        "deltas": "uses_deltas",
    }
    obsolete_keys = ("lda", "fmllr")

    def _upgrade(block: MetaDict) -> None:
        # Drop keys that are no longer supported, then rename legacy keys
        for obsolete in obsolete_keys:
            block.pop(obsolete, None)
        for old_name, new_name in renamed_keys.items():
            if old_name in block:
                block[new_name] = block.pop(old_name)

    _upgrade(config["features"] if "features" in config else config)
    return config


def configure_logger(identifier: str, log_file: typing.Optional[Path] = None) -> None:
    """
    Configure logging for the given identifier

    The logger is set to DEBUG. An optional file handler records every message,
    and, unless the current MFA profile is quiet, a rich console handler is added
    at DEBUG (verbose profile) or INFO level.

    Parameters
    ----------
    identifier: str
        Logger identifier
    log_file: :class:`~pathlib.Path`, optional
        Path to file to write all messages to
    """
    from montreal_forced_aligner.config import MfaConfiguration

    config = MfaConfiguration()
    logger = logging.getLogger(identifier)
    logger.setLevel(logging.DEBUG)

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, encoding="utf8")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(file_handler)

    if config.current_profile.quiet:
        return

    console_handler = RichHandler(
        rich_tracebacks=True, log_time_format="", console=console, show_path=False
    )
    console_handler.setLevel(logging.DEBUG if config.current_profile.verbose else logging.INFO)
    console_handler.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(console_handler)


def comma_join(sequence: typing.List[typing.Any]) -> str:
    """
    Helper function to combine a list into a human-readable expression with commas and a
    final "and" separator

    Items are converted with :func:`str`, so non-string items are accepted.

    Parameters
    ----------
    sequence: list[Any]
        Items to join together into a list

    Returns
    -------
    str
        Joined items in list format, e.g. ``"a, b, and c"``
    """
    items = [str(item) for item in sequence]
    if len(items) < 3:
        return " and ".join(items)
    return f"{', '.join(items[:-1])}, and {items[-1]}"


def make_re_character_set_safe(
    characters: typing.Collection[str], extra_strings: typing.Optional[typing.List[str]] = None
) -> str:
    """
    Build a regex character-class string from a collection of characters

    The characters are sorted and escaped with :func:`re.escape`. A ``"-"`` in the
    input is moved to the front of the class, and any ``extra_strings`` are inserted
    verbatim (unescaped) after it.

    Parameters
    ----------
    characters: Collection[str]
        Characters to compile
    extra_strings: list[str], optional
        Optional other strings to put in the character class

    Returns
    -------
    str
        Character set specifier for re functions
    """
    ordered = sorted(characters)
    prefix_parts = []
    if "-" in ordered:
        prefix_parts.append("-")
        ordered = [c for c in ordered if c != "-"]
    if extra_strings:
        prefix_parts.extend(extra_strings)
    prefix = "".join(prefix_parts)
    body = re.escape("".join(ordered))
    return "[" + prefix + body + "]"


def make_safe(element: typing.Any) -> str:
    """
    Helper function to make an element a string

    Lists (including nested lists) are flattened into space-separated text; every
    other value goes through :func:`str`.

    Parameters
    ----------
    element: Any
        Element to recursively turn into a string

    Returns
    -------
    str
        All elements combined into a string
    """
    if not isinstance(element, list):
        return str(element)
    return " ".join(make_safe(item) for item in element)


def make_scp_safe(string: str) -> str:
    """
    Helper function to make a string safe for saving in Kaldi scp files

    Kaldi scp files are space-delimited, so every space in the (stringified) input
    is replaced by ``"_MFASPACE_"`` to preserve it.

    Parameters
    ----------
    string: str
        Text to escape

    Returns
    -------
    str
        Escaped text
    """
    return "_MFASPACE_".join(str(string).split(" "))


def load_scp_safe(string: str) -> str:
    """
    Helper function to load previously made safe text

    Every ``"_MFASPACE_"`` marker is turned back into a regular space character.

    Parameters
    ----------
    string: str
        String to convert

    Returns
    -------
    str
        Converted string
    """
    return " ".join(string.split("_MFASPACE_"))


def output_mapping(
    mapping: typing.Dict[str, typing.Any], path: Path, skip_safe: bool = False
) -> None:
    """
    Helper function to save mapping information (i.e., utt2spk) in Kaldi scp format

    Keys are written in sorted order and always escaped with ``make_scp_safe``.
    List, set and tuple values are written as space-separated items; other values
    are escaped unless ``skip_safe`` is set. Nothing is written for an empty mapping.

    Parameters
    ----------
    mapping: dict[str, Any]
        Mapping to output
    path: :class:`~pathlib.Path`
        Path to save mapping
    skip_safe: bool, optional
        Flag for whether to skip over making a string safe
    """
    if not mapping:
        return
    with mfa_open(path, "w") as f:
        for key in sorted(mapping):
            value = mapping[key]
            if isinstance(value, (list, set, tuple)):
                text = " ".join(str(item) for item in value)
            elif skip_safe:
                text = value
            else:
                text = make_scp_safe(value)
            f.write(f"{make_scp_safe(key)} {text}\n")


def load_scp(
    path: Path, data_type: typing.Optional[typing.Type] = str
) -> typing.Dict[str, typing.Any]:
    """
    Load a Kaldi script file (.scp)

    Each non-blank line is split on whitespace. The first field is the key, with
    ``_MFASPACE_`` markers restored to spaces. If exactly one field follows, it is
    converted with ``data_type``; string results are also unescaped. Otherwise the
    value is a list of ``data_type``-converted fields, with bare ``[`` / ``]``
    tokens dropped.

    See Also
    --------
    :kaldi_docs:`io#io_sec_scp_details`
        For more information on the SCP format

    Parameters
    ----------
    path : :class:`~pathlib.Path`
        Path to Kaldi script file
    data_type : type
        Type to coerce the data to

    Returns
    -------
    dict[str, Any]
        Dictionary where the keys are the first column and the values are all
        other columns in the scp file
    """
    scp = {}
    with mfa_open(path, "r") as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                continue
            first, *fields = stripped.split()
            key = load_scp_safe(first)
            if len(fields) == 1:
                value = data_type(fields[0])
                if isinstance(value, str):
                    value = load_scp_safe(value)
            else:
                value = [data_type(field) for field in fields if field not in ["[", "]"]]
            scp[key] = value
    return scp


def edit_distance(x: typing.List[str], y: typing.List[str]) -> int:
    """
    Compute the Levenshtein edit distance between two sequences of labels

    Insertions, deletions and substitutions each cost 1. The table is filled one
    row at a time, in O(len(x) * len(y)) time and O(len(y)) extra memory, with
    plain Python ints, so long sequences cannot overflow.

    See Also
    --------
    `https://gist.github.com/kylebgorman/8034009 <https://gist.github.com/kylebgorman/8034009>`_
        For a more expressive version of this function

    Parameters
    ----------
    x: list[str]
        First sequence to compare
    y: list[str]
        Second sequence to compare

    Returns
    -------
    int
        Edit distance
    """
    # Distance from the empty prefix of x to each prefix of y is the prefix length
    previous = list(range(len(y) + 1))
    for i, x_label in enumerate(x, start=1):
        # Distance from x[:i] to the empty prefix of y is i deletions
        current = [i]
        for j, y_label in enumerate(y, start=1):
            if x_label == y_label:
                current.append(previous[j - 1])
            else:
                current.append(min(previous[j], current[j - 1], previous[j - 1]) + 1)
        previous = current
    return previous[-1]


def score_g2p(gold: typing.List[str], hypo: typing.List[str]) -> typing.Tuple[int, int]:
    """
    Computes sufficient statistics for LER calculation.

    Each pronunciation is a space-separated string of labels. The best (lowest edit
    distance) pairing of a gold and a hypothesized pronunciation is chosen, and its
    label-level edit distance is returned with the number of labels in that gold
    pronunciation. Both counts are in labels, so they can be used directly as the
    numerator and denominator of a label error rate.

    If either list is empty, the sentinel ``(100000, 100000)`` is returned.

    Parameters
    ----------
    gold: list[str]
        The reference pronunciations
    hypo: list[str]
        The hypothesized pronunciations

    Returns
    -------
    int
        Edit distance
    int
        Length of the gold labels
    """
    for h in hypo:
        if h in gold:
            return 0, len(h.split())
    edits = 100000
    best_length = 100000
    for g, h in itertools.product(gold, hypo):
        gold_labels = g.split()
        e = edit_distance(gold_labels, h.split())
        if e < edits:
            edits = e
            # Count labels, not characters, to match the unit of the edit distance
            best_length = len(gold_labels)
            if not edits:
                break
    return edits, best_length


def score_wer(
    gold: typing.List[str], hypo: typing.List[str], filter_brackets=True
) -> typing.Tuple[int, int, int, int]:
    """
    Computes word error rate and character error rate statistics for a transcription

    When ``filter_brackets`` is set, words starting with ``[``, ``<`` or ``{`` are
    dropped from both sequences first. Character statistics are computed on the
    concatenation of the remaining words, without separators.

    Parameters
    ----------
    gold: list[str]
        The reference words
    hypo: list[str]
        The hypothesized words
    filter_brackets : bool
        Flag for whether to ignore bracketed words

    Returns
    -------
    int
        Word Edit distance
    int
        Length of the gold words labels
    int
        Character edit distance
    int
        Length of the gold characters
    """
    if filter_brackets:
        openers = ("[", "<", "{")
        gold = [word for word in gold if not word.startswith(openers)]
        hypo = [word for word in hypo if not word.startswith(openers)]
    word_edits = edit_distance(gold, hypo)
    gold_characters = list("".join(gold))
    hypo_characters = list("".join(hypo))
    character_edits = edit_distance(gold_characters, hypo_characters)
    return word_edits, len(gold), character_edits, len(gold_characters)


class EnhancedJSONEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes dataclassy dataclasses and sets"""

    def default(self, o: typing.Any) -> typing.Any:
        """
        Convert objects the standard JSON encoder cannot handle

        Parameters
        ----------
        o: Any
            Object to serialize

        Returns
        -------
        Any
            A dictionary for dataclassy instances, a list for sets

        Raises
        ------
        TypeError
            If the object is of an unsupported type
        """
        if dataclassy.functions.is_dataclass_instance(o):
            return dataclassy.asdict(o)
        if isinstance(o, set):
            return list(o)
        # Defer to the base class so unsupported objects raise json's documented
        # TypeError, instead of being passed to dataclassy.asdict, which expects a dataclass
        return super().default(o)


def load_evaluation_mapping(custom_mapping_path):
    """
    Load a custom evaluation mapping from a yaml file

    Parameters
    ----------
    custom_mapping_path: :class:`~pathlib.Path` or str
        Path to a yaml file mapping each key to a single label or a list of labels

    Returns
    -------
    dict
        Mapping of each key (as loaded from yaml) to a set of labels; empty if the
        file contains no data
    """
    with mfa_open(custom_mapping_path, "r") as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects.
        # Consider yaml.safe_load if mapping files may come from untrusted sources.
        mapping = yaml.load(f, Loader=yaml.Loader)
    # An empty yaml document loads as None
    if not mapping:
        return {}
    for k, v in mapping.items():
        if isinstance(v, str):
            mapping[k] = {v}
        else:
            mapping[k] = set(v)
    return mapping


def format_probability(probability_value: float) -> float:
    """Format a probability to have two decimal places and be between 0.01 and 0.99"""
    return min(max(round(probability_value, 2), 0.01), 0.99)


def format_correction(correction_value: float, positive_only=True) -> float:
    """
    Format a probability correction value to have two decimal places

    When ``positive_only`` is set, values that round to zero or below are replaced
    with 0.01, so the result is at least 0.01.
    """
    correction_value = round(correction_value, 2)
    if correction_value <= 0 and positive_only:
        correction_value = 0.01
    return correction_value