"""
Helper functions
================
"""
from __future__ import annotations
import collections
import functools
import itertools
import json
import logging
import re
import typing
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import dataclassy
import numpy
import yaml
from rich.console import Console
from rich.logging import RichHandler
from rich.theme import Theme
with warnings.catch_warnings():
warnings.simplefilter("ignore")
from Bio import pairwise2
if TYPE_CHECKING:
from kalpy.fstext.lexicon import LexiconCompiler
from montreal_forced_aligner.abc import MetaDict
from montreal_forced_aligner.data import CtmInterval
__all__ = [
"comma_join",
"make_safe",
"make_scp_safe",
"load_scp",
"load_scp_safe",
"score_wer",
"score_g2p",
"edit_distance",
"output_mapping",
"parse_old_features",
"compare_labels",
"overlap_scoring",
"make_re_character_set_safe",
"align_phones",
"split_phone_position",
"align_pronunciations",
"configure_logger",
"mfa_open",
"load_configuration",
"format_correction",
"format_probability",
"load_evaluation_mapping",
]
console = Console(
theme=Theme(
{
"logging.level.debug": "cyan",
"logging.level.info": "green",
"logging.level.warning": "yellow",
"logging.level.error": "red",
}
),
stderr=True,
)
@contextmanager
def mfa_open(path, mode="r", encoding="utf8", newline=""):
if "r" in mode:
if "b" in mode:
file = open(path, mode)
else:
file = open(path, mode, encoding=encoding)
else:
if "b" in mode:
file = open(path, mode)
else:
file = open(path, mode, encoding=encoding, newline=newline)
try:
yield file
finally:
file.close()
def load_configuration(config_path: typing.Union[str, Path]) -> Dict[str, Any]:
"""
Load a configuration file
Parameters
----------
config_path: :class:`~pathlib.Path`
Path to yaml or json configuration file
Returns
-------
dict[str, Any]
Configuration dictionary
"""
data = {}
if not isinstance(config_path, Path):
config_path = Path(config_path)
with mfa_open(config_path, "r") as f:
if config_path.suffix == ".yaml":
data = yaml.load(f, Loader=yaml.Loader)
elif config_path.suffix == ".json":
data = json.load(f)
if not data:
return {}
return data
def split_phone_position(phone_label: str) -> List[str]:
"""
Splits a phone label into its original phone and it's positional label
Parameters
----------
phone_label: str
Phone label
Returns
-------
List[str]
Phone and position
"""
phone = phone_label
pos = None
try:
phone, pos = phone_label.rsplit("_", maxsplit=1)
except ValueError:
pass
return phone, pos
def parse_old_features(config: MetaDict) -> MetaDict:
"""
Backwards compatibility function to parse old feature configuration blocks
Parameters
----------
config: dict[str, Any]
Configuration parameters
Returns
-------
dict[str, Any]
Up to date versions of feature blocks
"""
feature_key_remapping = {
"type": "feature_type",
"deltas": "uses_deltas",
}
skip_keys = ["lda", "fmllr"]
if "features" in config:
for key in skip_keys:
if key in config["features"]:
del config["features"][key]
for key, new_key in feature_key_remapping.items():
if key in config["features"]:
config["features"][new_key] = config["features"][key]
del config["features"][key]
else:
for key in skip_keys:
if key in config:
del config[key]
for key, new_key in feature_key_remapping.items():
if key in config:
config[new_key] = config[key]
del config[key]
return config
def configure_logger(identifier: str, log_file: Optional[Path] = None) -> None:
"""
Configure logging for the given identifier
Parameters
----------
identifier: str
Logger identifier
log_file: str
Path to file to write all messages to
"""
from montreal_forced_aligner.config import MfaConfiguration
config = MfaConfiguration()
logger = logging.getLogger(identifier)
logger.setLevel(logging.DEBUG)
if log_file is not None:
file_handler = logging.FileHandler(log_file, encoding="utf8")
file_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if not config.current_profile.quiet:
handler = RichHandler(
rich_tracebacks=True, log_time_format="", console=console, show_path=False
)
if config.current_profile.verbose:
handler.setLevel(logging.DEBUG)
else:
handler.setLevel(logging.INFO)
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
[docs]
def comma_join(sequence: List[Any]) -> str:
"""
Helper function to combine a list into a human-readable expression with commas and a
final "and" separator
Parameters
----------
sequence: list[Any]
Items to join together into a list
Returns
-------
str
Joined items in list format
"""
if len(sequence) < 3:
return " and ".join(sequence)
return f"{', '.join(sequence[:-1])}, and {sequence[-1]}"
def make_re_character_set_safe(
characters: typing.Collection[str], extra_strings: Optional[List[str]] = None
) -> str:
"""
Construct a character set string for use in regex, escaping necessary characters and
moving "-" to the initial position
Parameters
----------
characters: Collection[str]
Characters to compile
extra_strings: list[str], optional
Optional other strings to put in the character class
Returns
-------
str
Character set specifier for re functions
"""
characters = sorted(characters)
extra = ""
if "-" in characters:
extra = "-"
characters = [x for x in characters if x != "-"]
if extra_strings:
extra += "".join(extra_strings)
return f"[{extra}{re.escape(''.join(characters))}]"
[docs]
def make_safe(element: Any) -> str:
"""
Helper function to make an element a string
Parameters
----------
element: Any
Element to recursively turn into a string
Returns
-------
str
All elements combined into a string
"""
if isinstance(element, list):
return " ".join(map(make_safe, element))
return str(element)
[docs]
def make_scp_safe(string: str) -> str:
"""
Helper function to make a string safe for saving in Kaldi scp files. They use space as a delimiter, so
any spaces in the string will be converted to "_MFASPACE_" to preserve them
Parameters
----------
string: str
Text to escape
Returns
-------
str
Escaped text
"""
return str(string).replace(" ", "_MFASPACE_")
[docs]
def load_scp_safe(string: str) -> str:
"""
Helper function to load previously made safe text. All instances of "_MFASPACE_" will be converted to a
regular space character
Parameters
----------
string: str
String to convert
Returns
-------
str
Converted string
"""
return string.replace("_MFASPACE_", " ")
[docs]
def output_mapping(mapping: Dict[str, Any], path: Path, skip_safe: bool = False) -> None:
"""
Helper function to save mapping information (i.e., utt2spk) in Kaldi scp format
CorpusMappingType is either a dictionary of key to value for
one-to-one mapping case and a dictionary of key to list of values for one-to-many case.
Parameters
----------
mapping: dict[str, Any]
Mapping to output
path: :class:`~pathlib.Path`
Path to save mapping
skip_safe: bool, optional
Flag for whether to skip over making a string safe
"""
if not mapping:
return
with mfa_open(path, "w") as f:
for k in sorted(mapping.keys()):
v = mapping[k]
if isinstance(v, (list, set, tuple)):
v = " ".join(map(str, v))
elif not skip_safe:
v = make_scp_safe(v)
f.write(f"{make_scp_safe(k)} {v}\n")
[docs]
def load_scp(path: Path, data_type: Optional[Type] = str) -> Dict[str, Any]:
"""
Load a Kaldi script file (.scp)
Scp files in Kaldi can either be one-to-one or one-to-many, with the first element separated by
whitespace as the key and the remaining whitespace-delimited elements the values.
Returns a dictionary of key to value for
one-to-one mapping case and a dictionary of key to list of values for one-to-many case.
See Also
--------
:kaldi_docs:`io#io_sec_scp_details`
For more information on the SCP format
Parameters
----------
path : :class:`~pathlib.Path`
Path to Kaldi script file
data_type : type
Type to coerce the data to
Returns
-------
dict[str, Any]
Dictionary where the keys are the first column and the values are all
other columns in the scp file
"""
scp = {}
with mfa_open(path, "r") as f:
for line in f:
line = line.strip()
if line == "":
continue
line_list = line.split()
key = load_scp_safe(line_list.pop(0))
if len(line_list) == 1:
value = data_type(line_list[0])
if isinstance(value, str):
value = load_scp_safe(value)
else:
value = [data_type(x) for x in line_list if x not in ["[", "]"]]
scp[key] = value
return scp
[docs]
def edit_distance(x: List[str], y: List[str]) -> int:
"""
Compute edit distance between two sets of labels
See Also
--------
`https://gist.github.com/kylebgorman/8034009 <https://gist.github.com/kylebgorman/8034009>`_
For a more expressive version of this function
Parameters
----------
x: list[str]
First sequence to compare
y: list[str]
Second sequence to compare
Returns
-------
int
Edit distance
"""
idim = len(x) + 1
jdim = len(y) + 1
table = numpy.zeros((idim, jdim), dtype=numpy.uint8)
table[1:, 0] = 1
table[0, 1:] = 1
for i in range(1, idim):
for j in range(1, jdim):
if x[i - 1] == y[j - 1]:
table[i][j] = table[i - 1][j - 1]
else:
c1 = table[i - 1][j]
c2 = table[i][j - 1]
c3 = table[i - 1][j - 1]
table[i][j] = min(c1, c2, c3) + 1
return int(table[-1][-1])
def score_g2p(gold: List[str], hypo: List[str]) -> Tuple[int, int]:
"""
Computes sufficient statistics for LER calculation.
Parameters
----------
gold: WordData
The reference labels
hypo: WordData
The hypothesized labels
Returns
-------
int
Edit distance
int
Length of the gold labels
"""
for h in hypo:
if h in gold:
return 0, len(h)
edits = 100000
best_length = 100000
for g, h in itertools.product(gold, hypo):
e = edit_distance(g.split(), h.split())
if e < edits:
edits = e
best_length = len(g)
if not edits:
best_length = len(g)
break
return edits, best_length
[docs]
def score_wer(gold: List[str], hypo: List[str]) -> Tuple[int, int, int, int]:
"""
Computes word error rate and character error rate for a transcription
Parameters
----------
gold: list[str]
The reference words
hypo: list[str]
The hypothesized words
Returns
-------
int
Word Edit distance
int
Length of the gold words labels
int
Character edit distance
int
Length of the gold characters
"""
word_edits = edit_distance(gold, hypo)
character_gold = list("".join(gold))
character_hypo = list("".join(hypo))
character_edits = edit_distance(character_gold, character_hypo)
return word_edits, len(gold), character_edits, len(character_gold)
[docs]
def compare_labels(
ref: str, test: str, silence_phone: str, mapping: Optional[Dict[str, str]] = None
) -> int:
"""
Parameters
----------
ref: str
test: str
mapping: Optional[dict[str, str]]
Returns
-------
int
0 if labels match or they're in mapping, 2 otherwise
"""
if ref == test:
return 0
if ref == silence_phone or test == silence_phone:
return 10
if mapping is not None and test in mapping:
if isinstance(mapping[test], str):
if mapping[test] == ref:
return 0
elif ref in mapping[test]:
return 0
ref = ref.lower()
test = test.lower()
if ref == test:
return 0
return 2
[docs]
def overlap_scoring(
first_element: CtmInterval,
second_element: CtmInterval,
silence_phone: str,
mapping: Optional[Dict[str, str]] = None,
) -> float:
r"""
Method to calculate overlap scoring
.. math::
Score = -(\lvert begin_{1} - begin_{2} \rvert + \lvert end_{1} - end_{2} \rvert + \begin{cases}
0, & if label_{1} = label_{2} \\
2, & otherwise
\end{cases})
See Also
--------
`Blog post <https://memcauliffe.com/update-on-montreal-forced-aligner-performance.html>`_
For a detailed example that using this metric
Parameters
----------
first_element: :class:`~montreal_forced_aligner.data.CtmInterval`
First CTM interval to compare
second_element: :class:`~montreal_forced_aligner.data.CtmInterval`
Second CTM interval
mapping: Optional[dict[str, str]]
Optional mapping of phones to treat as matches even if they have different symbols
Returns
-------
float
Score calculated as the negative sum of the absolute different in begin timestamps, absolute difference in end
timestamps and the label score
"""
begin_diff = abs(first_element.begin - second_element.begin)
end_diff = abs(first_element.end - second_element.end)
label_diff = compare_labels(first_element.label, second_element.label, silence_phone, mapping)
return -1 * (begin_diff + end_diff + label_diff)
class EnhancedJSONEncoder(json.JSONEncoder):
"""JSON serialization"""
def default(self, o: Any) -> Any:
"""Get the dictionary of a dataclass"""
if dataclassy.functions.is_dataclass_instance(o):
return dataclassy.asdict(o)
if isinstance(o, set):
return list(o)
return dataclassy.asdict(o)
def align_pronunciations(
ref_text: typing.List[str],
pronunciations: typing.List[str],
oov_phone: str,
silence_phone: str,
silence_word: str,
word_pronunciations: typing.Dict[str, typing.Set[str]],
):
def score_function(ref: str, pron: typing.List[str]):
if not word_pronunciations:
return 0
if ref in word_pronunciations and pron in word_pronunciations[ref]:
return 0
if pron == oov_phone:
return 0
return -2
alignments = pairwise2.align.globalcs(
ref_text,
pronunciations,
score_function,
-1 if word_pronunciations else -5,
-1 if word_pronunciations else -5,
gap_char=["-"],
one_alignment_only=True,
)
transformed_pronunciations = []
for a in alignments:
for i, sa in enumerate(a.seqA):
sb = a.seqB[i]
if sa == "-" and sb == silence_phone:
sa = silence_word
if "-" in (sa, sb):
continue
transformed_pronunciations.append((sa, sb.split()))
return transformed_pronunciations
def load_evaluation_mapping(custom_mapping_path):
with mfa_open(custom_mapping_path, "r") as f:
mapping = yaml.load(f, Loader=yaml.Loader)
for k, v in mapping.items():
if isinstance(v, str):
mapping[k] = {v}
else:
mapping[k] = set(v)
return mapping
def fix_many_to_one_alignments(alignments, custom_mapping):
test_keys = set(x for x in custom_mapping.keys() if " " in x)
ref_keys = set()
for val in custom_mapping.values():
ref_keys.update(x for x in val if " " in x)
new_ref = []
new_test = []
for a in alignments:
for i, sa in enumerate(a.seqA):
sb = a.seqB[i]
if i != 0:
prev_sa = a.seqA[i - 1]
prev_sb = a.seqB[i - 1]
ref_key = " ".join(x.label for x in [prev_sa, sa] if x != "-")
test_key = " ".join(x.label for x in [prev_sb, sb] if x != "-")
if (
ref_key in ref_keys
and test_key in custom_mapping
and ref_key in custom_mapping[test_key]
):
new_ref[-1].label = ref_key
new_ref[-1].end = sa.end
if sb != "-":
new_test.append(sb)
continue
if (
test_key in test_keys
and test_key in custom_mapping
and ref_key in custom_mapping[test_key]
):
new_test[-1].label = test_key
new_test[-1].end = sb.end
if sa != "-":
new_ref.append(sa)
continue
if sa != "-":
new_ref.append(sa)
if sb != "-":
new_test.append(sb)
return new_ref, new_test
[docs]
def align_phones(
ref: List[CtmInterval],
test: List[CtmInterval],
silence_phone: str,
ignored_phones: typing.Set[str] = None,
custom_mapping: Optional[Dict[str, str]] = None,
debug: bool = False,
) -> Tuple[float, float, Dict[Tuple[str, str], int]]:
"""
Align phones based on how much they overlap and their phone label, with the ability to specify a custom mapping for
different phone labels to be scored as if they're the same phone
Parameters
----------
ref: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
List of CTM intervals as reference
test: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
List of CTM intervals to compare to reference
silence_phone: str
Silence phone (these are ignored in the final calculation)
ignored_phones: set[str], optional
Phones that should be ignored in score calculations (silence phone is automatically added)
custom_mapping: dict[str, str], optional
Mapping of phones to treat as matches even if they have different symbols
debug: bool, optional
Flag for logging extra information about alignments
Returns
-------
float
Score based on the average amount of overlap in phone intervals
float
Phone error rate
dict[tuple[str, str], int]
Dictionary of error pairs with their counts
"""
if ignored_phones is None:
ignored_phones = set()
if not isinstance(ignored_phones, set):
ignored_phones = set(ignored_phones)
if custom_mapping is None:
score_func = functools.partial(overlap_scoring, silence_phone=silence_phone)
else:
score_func = functools.partial(
overlap_scoring, silence_phone=silence_phone, mapping=custom_mapping
)
alignments = pairwise2.align.globalcs(
ref, test, score_func, -2, -2, gap_char=["-"], one_alignment_only=True
)
if custom_mapping is not None:
ref, test = fix_many_to_one_alignments(alignments, custom_mapping)
alignments = pairwise2.align.globalcs(
ref, test, score_func, -2, -2, gap_char=["-"], one_alignment_only=True
)
overlap_count = 0
overlap_sum = 0
num_insertions = 0
num_deletions = 0
num_substitutions = 0
errors = collections.Counter()
ignored_phones.add(silence_phone)
for a in alignments:
for i, sa in enumerate(a.seqA):
sb = a.seqB[i]
if sa == "-":
if sb.label not in ignored_phones:
errors[(sa, sb.label)] += 1
num_insertions += 1
else:
continue
elif sb == "-":
if sa.label not in ignored_phones:
errors[(sa.label, sb)] += 1
num_deletions += 1
else:
continue
else:
if sa.label in ignored_phones:
continue
overlap_sum += (abs(sa.begin - sb.begin) + abs(sa.end - sb.end)) / 2
overlap_count += 1
if compare_labels(sa.label, sb.label, silence_phone, mapping=custom_mapping) > 0:
num_substitutions += 1
errors[(sa.label, sb.label)] += 1
if overlap_count:
score = overlap_sum / overlap_count
else:
score = None
phone_error_rate = (num_insertions + num_deletions + (2 * num_substitutions)) / len(ref)
if debug:
import logging
logger = logging.getLogger("mfa")
logger.debug(
f"{pairwise2.format_alignment(*alignments[0])}\nScore: {score}\nPER: {phone_error_rate}\nErrors: {errors}"
)
return score, phone_error_rate, errors
def fix_unk_words(
ref: List[str],
test: List[CtmInterval],
lexicon_compiler: LexiconCompiler,
) -> Tuple[float, float, Dict[Tuple[str, str], int]]:
"""
Align phones based on how much they overlap and their phone label, with the ability to specify a custom mapping for
different phone labels to be scored as if they're the same phone
Parameters
----------
ref: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
List of CTM intervals as reference
test: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
List of CTM intervals to compare to reference
lexicon_compiler: LexiconCompiler
Lexicon compiler to use for evaluating the identity of OOV items
Returns
-------
float
Extra duration of new words
float
Word error rate
float
Aligned duration of found words
"""
from kalpy.gmm.data import WordCtmInterval
def score_func(ref, test):
ref_label = ref
if isinstance(ref_label, WordCtmInterval):
ref_label = ref_label.label
test_label = test
if isinstance(test_label, WordCtmInterval):
test_label = test_label.label
if ref_label == test_label:
return 0
if (
test_label == lexicon_compiler.silence_word
or ref_label == lexicon_compiler.silence_word
):
return -10
if lexicon_compiler.to_int(ref_label) == lexicon_compiler.to_int(test_label):
return 0
return -2
alignments = pairwise2.align.globalcs(
ref, test, score_func, -2, -2, gap_char=["-"], one_alignment_only=True
)
output_ctm = []
for a in alignments:
for i, sa in enumerate(a.seqA):
sb = a.seqB[i]
if sa == "-":
output_ctm.append(sb)
elif sb == "-":
continue
else:
if sa != sb.label and sb.label == lexicon_compiler.oov_word:
sb.label = sa
output_ctm.append(sb)
return output_ctm
def align_words(
ref: List[str],
test: List[CtmInterval],
silence_word: str,
ignored_words: typing.Set[str] = None,
debug: bool = False,
) -> Tuple[float, float, Dict[Tuple[str, str], int]]:
"""
Align phones based on how much they overlap and their phone label, with the ability to specify a custom mapping for
different phone labels to be scored as if they're the same phone
Parameters
----------
ref: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
List of CTM intervals as reference
test: list[:class:`~montreal_forced_aligner.data.CtmInterval`]
List of CTM intervals to compare to reference
silence_word: str
Silence word (these are ignored in the final calculation)
ignored_words: set[str], optional
Words that should be ignored in score calculations (silence phone is automatically added)
debug: bool, optional
Flag for logging extra information about alignments
Returns
-------
float
Extra duration of new words
float
Word error rate
float
Aligned duration of found words
"""
from montreal_forced_aligner.data import CtmInterval
if ignored_words is None:
ignored_words = set()
if not isinstance(ignored_words, set):
ignored_words = set(ignored_words)
def score_func(ref, test):
ref_label = ref
if isinstance(ref_label, CtmInterval):
ref_label = ref_label.label
test_label = test
if isinstance(test_label, CtmInterval):
test_label = test_label.label
if ref_label == test_label:
return 0
if test_label == silence_word or ref_label == silence_word:
return -10
return -2
alignments = pairwise2.align.globalcs(
ref, test, score_func, -2, -2, gap_char=["-"], one_alignment_only=True
)
num_insertions = 0
num_deletions = 0
num_substitutions = 0
ignored_words.add(silence_word)
extra_duration = 0
aligned_duration = 0
for a in alignments:
for i, sa in enumerate(a.seqA):
sb = a.seqB[i]
if sa == "-":
if sb.label not in ignored_words:
num_insertions += 1
extra_duration += sb.end - sb.begin
else:
continue
elif sb == "-":
if sa not in ignored_words:
num_deletions += 1
else:
continue
else:
if sa in ignored_words:
continue
if sa != sb.label:
num_substitutions += 1
else:
aligned_duration += sb.end - sb.begin
word_error_rate = (num_insertions + num_deletions + (2 * num_substitutions)) / len(ref)
if debug:
import logging
logger = logging.getLogger("mfa")
logger.debug(
f"{pairwise2.format_alignment(*alignments[0])}\nExtra word duration: {extra_duration}\nWER: {word_error_rate}"
)
return extra_duration, word_error_rate, aligned_duration
def format_probability(probability_value: float) -> float:
"""Format a probability to have two decimal places and be between 0.01 and 0.99"""
return min(max(round(probability_value, 2), 0.01), 0.99)
def format_correction(correction_value: float, positive_only=True) -> float:
"""Format a probability correction value to have two decimal places and be greater than 0.01"""
correction_value = round(correction_value, 2)
if correction_value <= 0 and positive_only:
correction_value = 0.01
return correction_value