"""
Helper functions
================
"""
from __future__ import annotations
import collections
import functools
import itertools
import json
import logging
import re
import typing
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import dataclassy
import numpy
import yaml
from Bio import pairwise2
from rich.console import Console
from rich.logging import RichHandler
from rich.theme import Theme
if TYPE_CHECKING:
from montreal_forced_aligner.abc import MetaDict
from montreal_forced_aligner.data import CtmInterval
__all__ = [
"comma_join",
"make_safe",
"make_scp_safe",
"load_scp",
"load_scp_safe",
"score_wer",
"score_g2p",
"edit_distance",
"output_mapping",
"parse_old_features",
"compare_labels",
"overlap_scoring",
"make_re_character_set_safe",
"align_phones",
"split_phone_position",
"align_pronunciations",
"configure_logger",
"mfa_open",
"load_configuration",
"format_correction",
"format_probability",
]
# Shared rich console used by all MFA logging handlers.  Writes to stderr so
# stdout stays clean for piped program output.
console = Console(
    theme=Theme(
        {
            # Color each logging level distinctly in terminal output
            "logging.level.debug": "cyan",
            "logging.level.info": "green",
            "logging.level.warning": "yellow",
            "logging.level.error": "red",
        }
    ),
    stderr=True,
)
@contextmanager
def mfa_open(path, mode="r", encoding="utf8", newline=""):
    """
    Context manager that opens ``path`` and guarantees the handle is closed.

    Binary modes ignore ``encoding`` and ``newline``; text reads use only
    ``encoding``; text writes additionally pass ``newline`` through so callers
    control line-ending translation.
    """
    if "b" in mode:
        # Binary I/O takes no encoding/newline arguments
        handle = open(path, mode)
    elif "r" in mode:
        handle = open(path, mode, encoding=encoding)
    else:
        handle = open(path, mode, encoding=encoding, newline=newline)
    try:
        yield handle
    finally:
        handle.close()
def load_configuration(config_path: typing.Union[str, Path]) -> Dict[str, Any]:
    """
    Load a configuration file

    Parameters
    ----------
    config_path: :class:`~pathlib.Path`
        Path to yaml or json configuration file

    Returns
    -------
    dict[str, Any]
        Configuration dictionary; empty for empty documents or
        unrecognized file suffixes
    """
    data = {}
    if not isinstance(config_path, Path):
        config_path = Path(config_path)
    with mfa_open(config_path, "r") as f:
        # Accept both common YAML suffixes
        if config_path.suffix in (".yaml", ".yml"):
            # NOTE(review): yaml.Loader can construct arbitrary Python objects;
            # acceptable for trusted local configs but unsafe on untrusted input.
            data = yaml.load(f, Loader=yaml.Loader)
        elif config_path.suffix == ".json":
            data = json.load(f)
    # An empty document parses to None/{}: normalize to an empty dict
    if not data:
        return {}
    return data
def split_phone_position(phone_label: str) -> Tuple[str, Optional[str]]:
    """
    Split a phone label into its base phone and its positional tag.

    A label such as ``"AH_begin"`` is split on the final underscore into
    ``("AH", "begin")``; labels without an underscore keep a ``None`` position.

    Parameters
    ----------
    phone_label: str
        Phone label, optionally suffixed with ``_<position>``

    Returns
    -------
    str
        Base phone
    Optional[str]
        Positional tag, or ``None`` if the label has no underscore
    """
    phone = phone_label
    pos = None
    try:
        phone, pos = phone_label.rsplit("_", maxsplit=1)
    except ValueError:
        # No underscore in the label: leave pos as None
        pass
    return phone, pos
def parse_old_features(config: MetaDict) -> MetaDict:
    """
    Backwards compatibility function to parse old feature configuration blocks

    Parameters
    ----------
    config: dict[str, Any]
        Configuration parameters

    Returns
    -------
    dict[str, Any]
        Up to date versions of feature blocks
    """
    renamed_keys = {
        "type": "feature_type",
        "deltas": "uses_deltas",
    }
    obsolete_keys = ("lda", "fmllr")
    # Old configs may nest feature options under a "features" key or keep them
    # at the top level; normalize whichever layout is present, in place.
    target = config["features"] if "features" in config else config
    for old_key in obsolete_keys:
        target.pop(old_key, None)
    for old_key, new_key in renamed_keys.items():
        if old_key in target:
            target[new_key] = target.pop(old_key)
    return config
def configure_logger(identifier: str, log_file: Optional[Path] = None) -> None:
    """
    Configure logging for the given identifier

    Parameters
    ----------
    identifier: str
        Logger identifier
    log_file: :class:`~pathlib.Path`, optional
        Path to file to write all messages to
    """
    from montreal_forced_aligner.config import MfaConfiguration

    config = MfaConfiguration()
    logger = logging.getLogger(identifier)
    logger.setLevel(logging.DEBUG)
    if log_file is not None:
        # The file handler always records everything at DEBUG level
        file_handler = logging.FileHandler(log_file, encoding="utf8")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(file_handler)
    if not config.current_profile.quiet:
        # Console verbosity follows the active MFA profile settings
        rich_handler = RichHandler(
            rich_tracebacks=True, log_time_format="", console=console, show_path=False
        )
        rich_handler.setLevel(
            logging.DEBUG if config.current_profile.verbose else logging.INFO
        )
        rich_handler.setFormatter(logging.Formatter("%(message)s"))
        logger.addHandler(rich_handler)
def comma_join(sequence: List[Any]) -> str:
    """
    Helper function to combine a list into a human-readable expression with commas and a
    final "and" separator

    Parameters
    ----------
    sequence: list[Any]
        Items to join together into a list

    Returns
    -------
    str
        Joined items in list format
    """
    # One or two items need no serial comma ("a", "a and b")
    if len(sequence) < 3:
        return " and ".join(sequence)
    return f"{', '.join(sequence[:-1])}, and {sequence[-1]}"
def make_re_character_set_safe(
    characters: typing.Collection[str], extra_strings: Optional[List[str]] = None
) -> str:
    """
    Construct a character set string for use in regex, escaping necessary characters and
    moving "-" to the initial position

    Parameters
    ----------
    characters: Collection[str]
        Characters to compile
    extra_strings: list[str], optional
        Optional other strings to put in the character class

    Returns
    -------
    str
        Character set specifier for re functions
    """
    ordered = sorted(characters)
    # A literal "-" must come first inside a character class so it is not
    # interpreted as a range operator.
    prefix = "-" if "-" in ordered else ""
    ordered = [c for c in ordered if c != "-"]
    if extra_strings:
        # Extra strings are inserted verbatim (unescaped) by design, so
        # callers can pass constructs like escape sequences.
        prefix += "".join(extra_strings)
    return f"[{prefix}{re.escape(''.join(ordered))}]"
def make_safe(element: Any) -> str:
    """
    Helper function to make an element a string

    Parameters
    ----------
    element: Any
        Element to recursively turn into a string

    Returns
    -------
    str
        All elements combined into a string
    """
    # Lists are flattened recursively and joined with spaces; every other
    # type (including tuples and sets) is stringified directly.
    if isinstance(element, list):
        return " ".join(map(make_safe, element))
    return str(element)
def make_scp_safe(string: str) -> str:
    """
    Helper function to make a string safe for saving in Kaldi scp files. They use space as a delimiter, so
    any spaces in the string will be converted to "_MFASPACE_" to preserve them

    Parameters
    ----------
    string: str
        Text to escape

    Returns
    -------
    str
        Escaped text
    """
    # str() coercion lets non-string values (e.g. ints) be passed through
    return str(string).replace(" ", "_MFASPACE_")
def load_scp_safe(string: str) -> str:
    """
    Helper function to load previously made safe text. All instances of "_MFASPACE_" will be converted to a
    regular space character

    Parameters
    ----------
    string: str
        String to convert

    Returns
    -------
    str
        Converted string
    """
    # Inverse of :func:`make_scp_safe`
    return string.replace("_MFASPACE_", " ")
def output_mapping(mapping: Dict[str, Any], path: Path, skip_safe: bool = False) -> None:
    """
    Helper function to save mapping information (i.e., utt2spk) in Kaldi scp format

    CorpusMappingType is either a dictionary of key to value for
    one-to-one mapping case and a dictionary of key to list of values for one-to-many case.

    Parameters
    ----------
    mapping: dict[str, Any]
        Mapping to output
    path: :class:`~pathlib.Path`
        Path to save mapping
    skip_safe: bool, optional
        Flag for whether to skip over making a string safe
    """
    # Nothing to write: do not create an empty file
    if not mapping:
        return
    with mfa_open(path, "w") as f:
        # Sort keys so output files are deterministic
        for k in sorted(mapping.keys()):
            v = mapping[k]
            if isinstance(v, (list, set, tuple)):
                # One-to-many: values are space-joined and NOT space-escaped
                v = " ".join(map(str, v))
            elif not skip_safe:
                v = make_scp_safe(v)
            f.write(f"{make_scp_safe(k)} {v}\n")
def load_scp(path: Path, data_type: Optional[Type] = str) -> Dict[str, Any]:
    """
    Load a Kaldi script file (.scp)

    Scp files in Kaldi can either be one-to-one or one-to-many, with the first element separated by
    whitespace as the key and the remaining whitespace-delimited elements the values.

    Returns a dictionary of key to value for
    one-to-one mapping case and a dictionary of key to list of values for one-to-many case.

    See Also
    --------
    :kaldi_docs:`io#io_sec_scp_details`
        For more information on the SCP format

    Parameters
    ----------
    path : :class:`~pathlib.Path`
        Path to Kaldi script file
    data_type : type
        Type to coerce the data to

    Returns
    -------
    dict[str, Any]
        Dictionary where the keys are the first column and the values are all
        other columns in the scp file
    """
    scp = {}
    with mfa_open(path, "r") as f:
        for line in f:
            line = line.strip()
            if line == "":
                continue
            line_list = line.split()
            key = load_scp_safe(line_list.pop(0))
            if len(line_list) == 1:
                # One-to-one mapping: single coerced value per key
                value = data_type(line_list[0])
                if isinstance(value, str):
                    value = load_scp_safe(value)
            else:
                # One-to-many mapping: Kaldi matrix brackets are dropped
                value = [data_type(x) for x in line_list if x not in ["[", "]"]]
            scp[key] = value
    return scp
def edit_distance(x: List[str], y: List[str]) -> int:
    """
    Compute the Levenshtein edit distance between two sequences of labels

    See Also
    --------
    `https://gist.github.com/kylebgorman/8034009 <https://gist.github.com/kylebgorman/8034009>`_
        For a more expressive version of this function

    Parameters
    ----------
    x: list[str]
        First sequence to compare
    y: list[str]
        Second sequence to compare

    Returns
    -------
    int
        Edit distance
    """
    idim = len(x) + 1
    jdim = len(y) + 1
    # Use a platform int: uint8 would silently wrap once distances exceed 255.
    table = numpy.zeros((idim, jdim), dtype=numpy.int64)
    # Border cells hold the cost of deleting/inserting each prefix (0..n),
    # not a constant, per the standard dynamic-programming recurrence.
    table[:, 0] = numpy.arange(idim)
    table[0, :] = numpy.arange(jdim)
    for i in range(1, idim):
        for j in range(1, jdim):
            if x[i - 1] == y[j - 1]:
                table[i][j] = table[i - 1][j - 1]
            else:
                c1 = table[i - 1][j]  # deletion
                c2 = table[i][j - 1]  # insertion
                c3 = table[i - 1][j - 1]  # substitution
                table[i][j] = min(c1, c2, c3) + 1
    return int(table[-1][-1])
def score_g2p(gold: List[str], hypo: List[str]) -> Tuple[int, int]:
    """
    Computes sufficient statistics for LER calculation.

    Parameters
    ----------
    gold: list[str]
        The reference pronunciations
    hypo: list[str]
        The hypothesized pronunciations

    Returns
    -------
    int
        Edit distance
    int
        Length of the gold labels
    """
    # An exact match against any reference pronunciation costs nothing
    for hypothesis in hypo:
        if hypothesis in gold:
            return 0, len(hypothesis)
    best_edits = 100000
    best_length = 100000
    # Otherwise keep the cheapest gold/hypo pairing, scored by token-level
    # edit distance over the space-separated phones.
    for reference, hypothesis in itertools.product(gold, hypo):
        distance = edit_distance(reference.split(), hypothesis.split())
        if distance < best_edits:
            best_edits = distance
            best_length = len(reference)
            if not best_edits:
                # Cannot do better than zero edits
                break
    return best_edits, best_length
def score_wer(gold: List[str], hypo: List[str]) -> Tuple[int, int, int, int]:
    """
    Computes word error rate and character error rate for a transcription

    Parameters
    ----------
    gold: list[str]
        The reference words
    hypo: list[str]
        The hypothesized words

    Returns
    -------
    int
        Word Edit distance
    int
        Length of the gold words labels
    int
        Character edit distance
    int
        Length of the gold characters
    """
    word_edits = edit_distance(gold, hypo)
    # Character comparison concatenates the words with no separator
    character_gold = list("".join(gold))
    character_hypo = list("".join(hypo))
    character_edits = edit_distance(character_gold, character_hypo)
    return word_edits, len(gold), character_edits, len(character_gold)
def compare_labels(
    ref: str, test: str, silence_phone: str, mapping: Optional[Dict[str, str]] = None
) -> int:
    """
    Score the mismatch between two phone labels.

    Parameters
    ----------
    ref: str
        Reference phone label
    test: str
        Hypothesized phone label
    silence_phone: str
        Silence phone; a mismatch involving silence is penalized more heavily
    mapping: Optional[dict[str, str]]
        Mapping of test phones to reference phone(s) to treat as matches

    Returns
    -------
    int
        0 if labels match (directly, via mapping, or case-insensitively),
        10 if exactly one label is the silence phone, 2 otherwise
    """
    if ref == test:
        return 0
    # ref == test was already ruled out, so silence here means only one side
    # is silence: the worst kind of mismatch.
    if ref == silence_phone or test == silence_phone:
        return 10
    if mapping is not None and test in mapping:
        # Mapping values may be a single phone or a collection of phones
        if isinstance(mapping[test], str):
            if mapping[test] == ref:
                return 0
        elif ref in mapping[test]:
            return 0
    # Fall back to a case-insensitive comparison
    ref = ref.lower()
    test = test.lower()
    if ref == test:
        return 0
    return 2
def overlap_scoring(
    first_element: CtmInterval,
    second_element: CtmInterval,
    silence_phone: str,
    mapping: Optional[Dict[str, str]] = None,
) -> float:
    r"""
    Method to calculate overlap scoring

    .. math::

       Score = -(\lvert begin_{1} - begin_{2} \rvert + \lvert end_{1} - end_{2} \rvert + \begin{cases}
                0, & if label_{1} = label_{2} \\
                2, & otherwise
                \end{cases})

    See Also
    --------
    `Blog post <https://memcauliffe.com/update-on-montreal-forced-aligner-performance.html>`_
        For a detailed example that using this metric

    Parameters
    ----------
    first_element: :class:`~montreal_forced_aligner.data.CtmInterval`
        First CTM interval to compare
    second_element: :class:`~montreal_forced_aligner.data.CtmInterval`
        Second CTM interval
    silence_phone: str
        Silence phone, passed through to :func:`compare_labels`
    mapping: Optional[dict[str, str]]
        Optional mapping of phones to treat as matches even if they have different symbols

    Returns
    -------
    float
        Score calculated as the negative sum of the absolute different in begin timestamps, absolute difference in end
        timestamps and the label score
    """
    begin_diff = abs(first_element.begin - second_element.begin)
    end_diff = abs(first_element.end - second_element.end)
    # Label mismatch contributes 0, 2, or 10 (silence) via compare_labels
    label_diff = compare_labels(first_element.label, second_element.label, silence_phone, mapping)
    return -1 * (begin_diff + end_diff + label_diff)
class EnhancedJSONEncoder(json.JSONEncoder):
    """JSON serialization that additionally handles sets and dataclass instances"""

    def default(self, o: Any) -> Any:
        """Convert sets and dataclass instances to JSON-serializable forms"""
        # Cheap isinstance check first; a set is never a dataclass instance,
        # so this reordering preserves behavior.
        if isinstance(o, set):
            return list(o)
        if dataclassy.functions.is_dataclass_instance(o):
            return dataclassy.asdict(o)
        # Delegate to the base class, which raises the conventional TypeError
        # for unserializable objects instead of failing inside
        # dataclassy.asdict on a non-dataclass.
        return super().default(o)
def align_pronunciations(
    ref_text: typing.List[str],
    pronunciations: typing.List[str],
    oov_phone: str,
    silence_phone: str,
    silence_word: str,
    word_pronunciations: typing.Dict[str, typing.Set[str]],
) -> typing.List[typing.Tuple[str, typing.List[str]]]:
    """
    Pair words in a reference transcript with pronunciation strings via a
    global sequence alignment (``Bio.pairwise2``).

    Parameters
    ----------
    ref_text: list[str]
        Reference words, in order
    pronunciations: list[str]
        Pronunciation strings (space-separated phones), in order
    oov_phone: str
        Phone for out-of-vocabulary items; aligns to any word at no cost
    silence_phone: str
        Phone representing silence
    silence_word: str
        Word substituted when a silence phone aligns against a word gap
    word_pronunciations: dict[str, set[str]]
        Known pronunciations per word; when empty, pronunciation-based
        scoring is disabled and larger gap penalties are used

    Returns
    -------
    list[tuple[str, list[str]]]
        Aligned (word, phones) pairs; unresolved gap pairings are dropped
    """

    def score_function(ref: str, pron: typing.List[str]):
        # Match score for one (word, pronunciation) pair: free when no
        # dictionary is available, when the pronunciation is known for this
        # word, or when it is the OOV phone; otherwise a flat -2 penalty.
        if not word_pronunciations:
            return 0
        if ref in word_pronunciations and pron in word_pronunciations[ref]:
            return 0
        if pron == oov_phone:
            return 0
        return -2

    # Gap penalties are milder (-1) when a pronunciation dictionary informs
    # the match scores, harsher (-5) otherwise.
    alignments = pairwise2.align.globalcs(
        ref_text,
        pronunciations,
        score_function,
        -1 if word_pronunciations else -5,
        -1 if word_pronunciations else -5,
        gap_char=["-"],
        one_alignment_only=True,
    )
    transformed_pronunciations = []
    for a in alignments:
        for i, sa in enumerate(a.seqA):
            sb = a.seqB[i]
            # A silence pronunciation aligned against a word gap is recovered
            # as the silence word rather than discarded.
            if sa == "-" and sb == silence_phone:
                sa = silence_word
            # Any remaining gap on either side cannot be paired; skip it
            if "-" in (sa, sb):
                continue
            # Split the pronunciation string into its individual phones
            transformed_pronunciations.append((sa, sb.split()))
    return transformed_pronunciations
def align_phones(
    ref: List[CtmInterval],
    test: List[CtmInterval],
    silence_phone: str,
    ignored_phones: Optional[typing.Set[str]] = None,
    custom_mapping: Optional[Dict[str, str]] = None,
    debug: bool = False,
) -> Tuple[float, float, Dict[Tuple[str, str], int]]:
    """
    Align phones based on how much they overlap and their phone label, with the ability to specify a custom mapping for
    different phone labels to be scored as if they're the same phone

    Parameters
    ----------
    ref: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
        List of CTM intervals as reference
    test: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
        List of CTM intervals to compare to reference
    silence_phone: str
        Silence phone (these are ignored in the final calculation)
    ignored_phones: set[str], optional
        Phone labels excluded entirely from scoring and error counts
    custom_mapping: dict[str, str], optional
        Mapping of phones to treat as matches even if they have different symbols
    debug: bool, optional
        Flag for logging extra information about alignments

    Returns
    -------
    float
        Score based on the average amount of overlap in phone intervals
        (``None`` when no intervals overlapped)
    float
        Phone error rate
    dict[tuple[str, str], int]
        Dictionary of error pairs with their counts
    """
    if ignored_phones is None:
        ignored_phones = set()
    # Bind silence phone (and optional mapping) so pairwise2 can invoke the
    # overlap scorer as a plain two-argument function.
    if custom_mapping is None:
        score_func = functools.partial(overlap_scoring, silence_phone=silence_phone)
    else:
        score_func = functools.partial(
            overlap_scoring, silence_phone=silence_phone, mapping=custom_mapping
        )
    alignments = pairwise2.align.globalcs(
        ref, test, score_func, -2, -2, gap_char=["-"], one_alignment_only=True
    )
    overlap_count = 0
    overlap_sum = 0
    num_insertions = 0
    num_deletions = 0
    num_substitutions = 0
    errors = collections.Counter()
    for a in alignments:
        for i, sa in enumerate(a.seqA):
            sb = a.seqB[i]
            if sa == "-":  # gap in the reference: an inserted phone
                if sb.label not in ignored_phones:
                    errors[(sa, sb.label)] += 1
                    num_insertions += 1
                else:
                    continue
            elif sb == "-":  # gap in the test sequence: a deleted phone
                if sa.label not in ignored_phones:
                    errors[(sa.label, sb)] += 1
                    num_deletions += 1
                else:
                    continue
            else:
                if sa.label in ignored_phones:
                    continue
                # Average of begin and end boundary discrepancies
                overlap_sum += (abs(sa.begin - sb.begin) + abs(sa.end - sb.end)) / 2
                overlap_count += 1
                if compare_labels(sa.label, sb.label, silence_phone, mapping=custom_mapping) > 0:
                    num_substitutions += 1
                    errors[(sa.label, sb.label)] += 1
    if debug:
        # Module-level logging import (top of file) is sufficient here
        logger = logging.getLogger("mfa")
        logger.debug(pairwise2.format_alignment(*alignments[0]))
    if overlap_count:
        score = overlap_sum / overlap_count
    else:
        score = None
    # Substitutions count double, matching the overlap-scoring label penalty.
    # NOTE(review): raises ZeroDivisionError when ``ref`` is empty — callers
    # are expected to pass a non-empty reference.
    phone_error_rate = (num_insertions + num_deletions + (2 * num_substitutions)) / len(ref)
    return score, phone_error_rate, errors
def format_probability(probability_value: float) -> float:
    """Format a probability to have two decimal places and be between 0.01 and 0.99"""
    rounded = round(probability_value, 2)
    if rounded < 0.01:
        return 0.01
    if rounded > 0.99:
        return 0.99
    return rounded
def format_correction(correction_value: float, positive_only=True) -> float:
    """Format a probability correction value to have two decimal places and be greater than 0.01"""
    rounded = round(correction_value, 2)
    # Non-positive corrections are floored at 0.01 unless negatives are allowed
    return 0.01 if (rounded <= 0 and positive_only) else rounded