"""Class for generating pronunciations from G2P models"""
from __future__ import annotations
import csv
import functools
import logging
import multiprocessing as mp
import os
import queue
import statistics
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
import pynini
import tqdm
from pynini import Fst, TokenType
from pynini.lib import rewrite
from pywrapfst import SymbolTable
from montreal_forced_aligner.abc import DatabaseMixin, TopLevelMfaWorker
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.corpus.text_corpus import TextCorpusMixin
from montreal_forced_aligner.exceptions import PyniniGenerationError
from montreal_forced_aligner.g2p.mixins import G2PTopLevelMixin
from montreal_forced_aligner.helper import comma_join, mfa_open, score_g2p
from montreal_forced_aligner.models import G2PModel
from montreal_forced_aligner.utils import Stopped
if TYPE_CHECKING:
SpeakerCharacterType = Union[str, int]
__all__ = [
"Rewriter",
"RewriterWorker",
"PyniniGenerator",
"PyniniCorpusGenerator",
"PyniniWordListGenerator",
]
logger = logging.getLogger("mfa")
def threshold_lattice_to_dfa(
lattice: pynini.Fst, threshold: float = 1.0, state_multiplier: int = 2
) -> pynini.Fst:
"""Constructs a (possibly pruned) weighted DFA of output strings.
Given an epsilon-free lattice of output strings (such as produced by
rewrite_lattice), attempts to determinize it, pruning non-optimal paths if
optimal_only is true. This is valid only in a semiring with the path property.
To prevent unexpected blowup during determinization, a state threshold is
also used and a warning is logged if this exact threshold is reached. The
threshold is a multiplier of the size of input lattice (by default, 4), plus
a small constant factor. This is intended by a sensible default and is not an
inherently meaningful value in and of itself.
Parameters
----------
lattice: :class:`~pynini.Fst`
Epsilon-free non-deterministic finite acceptor.
threshold: float
Threshold for weights, 1.0 is optimal only, 0 is for all paths, greater than 1
prunes the lattice to include paths with costs less than the optimal path's score times the threshold
state_multiplier: int
Max ratio for the number of states in the DFA lattice to the NFA lattice; if exceeded, a warning is logged.
Returns
-------
:class:`~pynini.Fst`
Epsilon-free deterministic finite acceptor.
"""
weight_type = lattice.weight_type()
weight_threshold = pynini.Weight(weight_type, threshold)
state_threshold = 256 + state_multiplier * lattice.num_states()
lattice = pynini.determinize(lattice, nstate=state_threshold, weight=weight_threshold)
return lattice
def optimal_rewrites(
string: pynini.FstLike,
rule: pynini.Fst,
input_token_type: Optional[TokenType] = None,
output_token_type: Optional[TokenType] = None,
threshold: float = 1,
) -> List[str]:
"""Returns all optimal rewrites.
Args:
string: Input string or FST.
rule: Input rule WFST.
input_token_type: Optional input token type, or symbol table.
output_token_type: Optional output token type, or symbol table.
threshold: Threshold for weights (1 is optimal only, 0 is for all paths)
Returns:
A tuple of output strings.
"""
lattice = rewrite.rewrite_lattice(string, rule, input_token_type)
lattice = threshold_lattice_to_dfa(lattice, threshold, 4)
return rewrite.lattice_to_strings(lattice, output_token_type)
[docs]
class Rewriter:
"""
Helper object for rewriting
Parameters
----------
fst: pynini.Fst
G2P FST model
input_token_type: pynini.TokenType
Grapheme symbol table or "utf8"
output_token_type: pynini.SymbolTable
Phone symbol table
num_pronunciations: int
Number of pronunciations, default to 0. If this is 0, thresholding is used
threshold: float
Threshold to use for pruning rewrite lattice, defaults to 1.5, only used if num_pronunciations is 0
"""
def __init__(
self,
fst: Fst,
input_token_type: TokenType,
output_token_type: SymbolTable,
num_pronunciations: int = 0,
threshold: float = 1,
):
if num_pronunciations > 0:
self.rewrite = functools.partial(
rewrite.top_rewrites,
nshortest=num_pronunciations,
rule=fst,
input_token_type=input_token_type,
output_token_type=output_token_type,
)
else:
self.rewrite = functools.partial(
optimal_rewrites,
threshold=threshold,
rule=fst,
input_token_type=input_token_type,
output_token_type=output_token_type,
)
def __call__(self, i: str) -> List[Tuple[str, ...]]: # pragma: no cover
"""Call the rewrite function"""
hypotheses = self.rewrite(i)
return [x for x in hypotheses if x]
class PhonetisaurusRewriter:
"""
Helper function for rewriting
Parameters
----------
fst: pynini.Fst
G2P FST model
input_token_type: pynini.SymbolTable
Grapheme symbol table
output_token_type: pynini.SymbolTable
num_pronunciations: int
Number of pronunciations, default to 0. If this is 0, thresholding is used
threshold: float
Threshold to use for pruning rewrite lattice, defaults to 1.5, only used if num_pronunciations is 0
grapheme_order: int
Maximum number of graphemes to consider single segment
seq_sep: str
Separator to use between grapheme symbols
"""
def __init__(
self,
fst: Fst,
input_token_type: SymbolTable,
output_token_type: SymbolTable,
num_pronunciations: int = 0,
threshold: float = 1.5,
grapheme_order: int = 2,
seq_sep: str = "|",
):
self.fst = fst
self.seq_sep = seq_sep
self.input_token_type = input_token_type
self.output_token_type = output_token_type
self.grapheme_order = grapheme_order
if num_pronunciations > 0:
self.rewrite = functools.partial(
rewrite.top_rewrites,
nshortest=num_pronunciations,
rule=fst,
input_token_type=None,
output_token_type=output_token_type,
)
else:
self.rewrite = functools.partial(
optimal_rewrites,
threshold=threshold,
rule=fst,
input_token_type=None,
output_token_type=output_token_type,
)
def __call__(self, graphemes: str) -> List[Tuple[str, ...]]: # pragma: no cover
"""Call the rewrite function"""
fst = pynini.Fst()
one = pynini.Weight.one(fst.weight_type())
max_state = 0
for i in range(len(graphemes)):
start_state = fst.add_state()
for j in range(1, self.grapheme_order + 1):
if i + j <= len(graphemes):
substring = self.seq_sep.join(graphemes[i : i + j])
ilabel = self.input_token_type.find(substring)
if ilabel != pynini.NO_LABEL:
fst.add_arc(start_state, pynini.Arc(ilabel, ilabel, one, i + j))
if i + j >= max_state:
max_state = i + j
for _ in range(fst.num_states(), max_state + 1):
fst.add_state()
fst.set_start(0)
fst.set_final(len(graphemes), one)
fst.set_input_symbols(self.input_token_type)
fst.set_output_symbols(self.input_token_type)
hypotheses = self.rewrite(fst)
hypotheses = [x.replace(self.seq_sep, " ") for x in hypotheses if x]
return hypotheses
[docs]
class RewriterWorker(mp.Process):
"""
Rewriter process
Parameters
----------
job_queue: :class:`~multiprocessing.Queue`
Queue to pull words from
return_queue: :class:`~multiprocessing.Queue`
Queue to put pronunciations
rewriter: :class:`~montreal_forced_aligner.g2p.generator.Rewriter`
Function to generate pronunciations of words
stopped: :class:`~montreal_forced_aligner.utils.Stopped`
Stop check
"""
def __init__(
self,
job_queue: mp.Queue,
return_queue: mp.Queue,
rewriter: Rewriter,
stopped: Stopped,
):
mp.Process.__init__(self)
self.job_queue = job_queue
self.return_queue = return_queue
self.rewriter = rewriter
self.stopped = stopped
self.finished = Stopped()
[docs]
def run(self) -> None:
"""Run the rewriting function"""
while True:
try:
word = self.job_queue.get(timeout=1)
except queue.Empty:
break
if self.stopped.stop_check():
continue
try:
rep = self.rewriter(word)
self.return_queue.put((word, rep))
except rewrite.Error:
pass
except Exception as e: # noqa
self.stopped.stop()
self.return_queue.put(e)
raise
self.finished.stop()
return
def clean_up_word(word: str, graphemes: Set[str]) -> Tuple[str, Set[str]]:
"""
Clean up word by removing graphemes not in a specified set
Parameters
----------
word : str
Input string
graphemes: set[str]
Set of allowable graphemes
Returns
-------
str
Cleaned up word
Set[str]
Graphemes excluded
"""
new_word = []
missing_graphemes = set()
for c in word:
if c not in graphemes:
missing_graphemes.add(c)
else:
new_word.append(c)
return "".join(new_word), missing_graphemes
class OrthographyGenerator(G2PTopLevelMixin):
"""
Abstract mixin class for generating "pronunciations" based off the orthographic word
See Also
--------
:class:`~montreal_forced_aligner.g2p.mixins.G2PTopLevelMixin`
For top level G2P generation parameters
"""
def generate_pronunciations(self) -> Dict[str, List[str]]:
"""
Generate pronunciations for the word set
Returns
-------
dict[str, Word]
Mapping of words to their "pronunciation"
"""
pronunciations = {}
for word in self.words_to_g2p:
pronunciations[word] = [" ".join(word)]
return pronunciations
[docs]
class PyniniGenerator(G2PTopLevelMixin):
"""
Class for generating pronunciations from a Pynini G2P model
Parameters
----------
g2p_model_path: str
Path to G2P model
strict_graphemes: bool
Flag for whether to be strict with missing graphemes and skip words containing new graphemes
See Also
--------
:class:`~montreal_forced_aligner.g2p.mixins.G2PTopLevelMixin`
For top level G2P generation parameters
Attributes
----------
g2p_model: G2PModel
G2P model
"""
def __init__(self, g2p_model_path: str, strict_graphemes: bool = False, **kwargs):
self.strict_graphemes = strict_graphemes
super().__init__(**kwargs)
self.g2p_model = G2PModel(
g2p_model_path, root_directory=getattr(self, "workflow_directory", None)
)
self.output_token_type = "utf8"
self.input_token_type = "utf8"
self.rewriter = None
def setup(self):
self.fst = pynini.Fst.read(self.g2p_model.fst_path)
if self.g2p_model.meta["architecture"] == "phonetisaurus":
self.output_token_type = pynini.SymbolTable.read_text(self.g2p_model.sym_path)
self.input_token_type = pynini.SymbolTable.read_text(self.g2p_model.grapheme_sym_path)
self.fst.set_input_symbols(self.input_token_type)
self.fst.set_output_symbols(self.output_token_type)
self.rewriter = PhonetisaurusRewriter(
self.fst,
self.input_token_type,
self.output_token_type,
num_pronunciations=self.num_pronunciations,
threshold=self.g2p_threshold,
grapheme_order=self.g2p_model.meta["grapheme_order"],
)
else:
if self.g2p_model.sym_path is not None and os.path.exists(self.g2p_model.sym_path):
self.output_token_type = pynini.SymbolTable.read_text(self.g2p_model.sym_path)
self.rewriter = Rewriter(
self.fst,
self.input_token_type,
self.output_token_type,
num_pronunciations=self.num_pronunciations,
threshold=self.g2p_threshold,
)
[docs]
def generate_pronunciations(self) -> Dict[str, List[str]]:
"""
Generate pronunciations
Returns
-------
dict[str, list[str]]
Mappings of keys to their generated pronunciations
"""
num_words = len(self.words_to_g2p)
begin = time.time()
missing_graphemes = set()
if self.rewriter is None:
self.setup()
logger.info("Generating pronunciations...")
to_return = {}
skipped_words = 0
if num_words < 30 or GLOBAL_CONFIG.num_jobs == 1:
with tqdm.tqdm(total=num_words, disable=GLOBAL_CONFIG.quiet) as pbar:
for word in self.words_to_g2p:
w, m = clean_up_word(word, self.g2p_model.meta["graphemes"])
pbar.update(1)
missing_graphemes = missing_graphemes | m
if self.strict_graphemes and m:
skipped_words += 1
continue
if not w:
skipped_words += 1
continue
try:
prons = self.rewriter(w)
except rewrite.Error:
continue
to_return[word] = prons
logger.debug(
f"Skipping {skipped_words} words for containing the following graphemes: "
f"{comma_join(sorted(missing_graphemes))}"
)
else:
stopped = Stopped()
job_queue = mp.Queue()
for word in self.words_to_g2p:
w, m = clean_up_word(word, self.g2p_model.meta["graphemes"])
missing_graphemes = missing_graphemes | m
if self.strict_graphemes and m:
skipped_words += 1
continue
if not w:
skipped_words += 1
continue
job_queue.put(w)
logger.debug(
f"Skipping {skipped_words} words for containing the following graphemes: "
f"{comma_join(sorted(missing_graphemes))}"
)
error_dict = {}
return_queue = mp.Queue()
procs = []
for _ in range(GLOBAL_CONFIG.num_jobs):
p = RewriterWorker(
job_queue,
return_queue,
self.rewriter,
stopped,
)
procs.append(p)
p.start()
num_words -= skipped_words
with tqdm.tqdm(total=num_words, disable=GLOBAL_CONFIG.quiet) as pbar:
while True:
try:
word, result = return_queue.get(timeout=1)
if stopped.stop_check():
continue
except queue.Empty:
for proc in procs:
if not proc.finished.stop_check():
break
else:
break
continue
pbar.update(1)
if isinstance(result, Exception):
error_dict[word] = result
continue
to_return[word] = result
for p in procs:
p.join()
if error_dict:
raise PyniniGenerationError(error_dict)
logger.debug(f"Processed {num_words} in {time.time() - begin:.3f} seconds")
return to_return
[docs]
class PyniniValidator(PyniniGenerator, TopLevelMfaWorker):
"""
Class for running validation for G2P model training
Parameters
----------
word_list: list[str]
List of words to generate pronunciations
See Also
--------
:class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator`
For parameters to generate pronunciations
"""
def __init__(self, word_list: List[str] = None, **kwargs):
super().__init__(**kwargs)
if word_list is None:
word_list = []
self.word_list = word_list
@property
def words_to_g2p(self) -> List[str]:
"""Words to produce pronunciations"""
return self.word_list
@property
def data_source_identifier(self) -> str:
"""Dummy "validation" data source"""
return "validation"
@property
def data_directory(self) -> str:
"""Data directory"""
return self.working_directory
@property
def evaluation_csv_path(self) -> str:
"""Path to working directory's CSV file"""
return os.path.join(self.working_directory, "pronunciation_evaluation.csv")
[docs]
def setup(self) -> None:
"""Set up the G2P validator"""
TopLevelMfaWorker.setup(self)
if self.initialized:
return
self._current_workflow = "validation"
os.makedirs(self.working_log_directory, exist_ok=True)
self.g2p_model.validate(self.words_to_g2p)
PyniniGenerator.setup(self)
self.initialized = True
self.wer = None
self.ler = None
[docs]
def compute_validation_errors(
self,
gold_values: Dict[str, Set[str]],
hypothesis_values: Dict[str, List[str]],
):
"""
Computes validation errors
Parameters
----------
gold_values: dict[str, set[str]]
Gold pronunciations
hypothesis_values: dict[str, list[str]]
Hypothesis pronunciations
"""
begin = time.time()
# Word-level measures.
correct = 0
incorrect = 0
# Label-level measures.
total_edits = 0
total_length = 0
# Since the edit distance algorithm is quadratic, let's do this with
# multiprocessing.
logger.debug(f"Processing results for {len(hypothesis_values)} hypotheses")
to_comp = []
indices = []
hyp_pron_count = 0
gold_pron_count = 0
output = []
for word, gold_pronunciations in gold_values.items():
if word not in hypothesis_values:
incorrect += 1
gold_length = statistics.mean(len(x.split()) for x in gold_pronunciations)
total_edits += gold_length
total_length += gold_length
output.append(
{
"Word": word,
"Gold pronunciations": ", ".join(gold_pronunciations),
"Hypothesis pronunciations": "",
"Accuracy": 0,
"Error rate": 1.0,
"Length": gold_length,
}
)
continue
hyp = hypothesis_values[word]
for h in hyp:
if h in gold_pronunciations:
correct += 1
total_length += len(h)
output.append(
{
"Word": word,
"Gold pronunciations": ", ".join(gold_pronunciations),
"Hypothesis pronunciations": ", ".join(hyp),
"Accuracy": 1,
"Error rate": 0.0,
"Length": len(h),
}
)
break
else:
incorrect += 1
indices.append(word)
to_comp.append((gold_pronunciations, hyp)) # Multiple hypotheses to compare
logger.debug(
f"For the word {word}: gold is {gold_pronunciations}, hypothesized are: {hyp}"
)
hyp_pron_count += len(hyp)
gold_pron_count += len(gold_pronunciations)
logger.debug(
f"Generated an average of {hyp_pron_count /len(hypothesis_values)} variants "
f"The gold set had an average of {gold_pron_count/len(hypothesis_values)} variants."
)
with mp.Pool(GLOBAL_CONFIG.num_jobs) as pool:
gen = pool.starmap(score_g2p, to_comp)
for i, (edits, length) in enumerate(gen):
word = indices[i]
gold_pronunciations = gold_values[word]
hyp = hypothesis_values[word]
output.append(
{
"Word": word,
"Gold pronunciations": ", ".join(gold_pronunciations),
"Hypothesis pronunciations": ", ".join(hyp),
"Accuracy": 1,
"Error rate": edits / length,
"Length": length,
}
)
total_edits += edits
total_length += length
with mfa_open(self.evaluation_csv_path, "w") as f:
writer = csv.DictWriter(
f,
fieldnames=[
"Word",
"Gold pronunciations",
"Hypothesis pronunciations",
"Accuracy",
"Error rate",
"Length",
],
)
writer.writeheader()
for line in output:
writer.writerow(line)
self.wer = 100 * incorrect / (correct + incorrect)
self.ler = 100 * total_edits / total_length
logger.info(f"WER:\t{self.wer:.2f}")
logger.info(f"LER:\t{self.ler:.2f}")
logger.debug(
f"Computation of errors for {len(gold_values)} words took {time.time() - begin:.3f} seconds"
)
[docs]
def evaluate_g2p_model(self, gold_pronunciations: Dict[str, Set[str]]) -> None:
"""
Evaluate a G2P model on the word list
Parameters
----------
gold_pronunciations: dict[str, set[str]]
Gold pronunciations
"""
output = self.generate_pronunciations()
self.compute_validation_errors(gold_pronunciations, output)
[docs]
class PyniniWordListGenerator(PyniniValidator, DatabaseMixin):
"""
Top-level worker for generating pronunciations from a word list and a Pynini G2P model
Parameters
----------
word_list_path: str
Path to word list file
See Also
--------
:class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator`
For Pynini G2P generation parameters
:class:`~montreal_forced_aligner.abc.TopLevelMfaWorker`
For top-level parameters
Attributes
----------
word_list: list[str]
Word list to generate pronunciations
"""
def __init__(self, word_list_path: str, **kwargs):
self.word_list_path = word_list_path
super().__init__(**kwargs)
@property
def data_directory(self) -> str:
"""Data directory"""
return self.working_directory
@property
def data_source_identifier(self) -> str:
"""Name of the word list file"""
return os.path.splitext(os.path.basename(self.word_list_path))[0]
[docs]
def setup(self) -> None:
"""Set up the G2P generator"""
if self.initialized:
return
with mfa_open(self.word_list_path, "r") as f:
for line in f:
self.word_list.extend(line.strip().split())
if not self.include_bracketed:
self.word_list = [x for x in self.word_list if not self.check_bracketed(x)]
super().setup()
self.g2p_model.validate(self.words_to_g2p)
self.initialized = True
[docs]
class PyniniCorpusGenerator(PyniniGenerator, TextCorpusMixin, TopLevelMfaWorker):
"""
Top-level worker for generating pronunciations from a corpus and a Pynini G2P model
See Also
--------
:class:`~montreal_forced_aligner.g2p.generator.PyniniGenerator`
For Pynini G2P generation parameters
:class:`~montreal_forced_aligner.corpus.text_corpus.TextCorpusMixin`
For corpus parsing parameters
:class:`~montreal_forced_aligner.abc.TopLevelMfaWorker`
For top-level parameters
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
[docs]
def setup(self) -> None:
"""Set up the pronunciation generator"""
if self.initialized:
return
self._load_corpus()
self.initialize_jobs()
super().setup()
self._create_dummy_dictionary()
self.normalize_text()
self.g2p_model.validate(self.words_to_g2p)
self.initialized = True
@property
def words_to_g2p(self) -> List[str]:
"""Words to produce pronunciations"""
word_list = self.corpus_word_set
if not self.include_bracketed:
word_list = [x for x in word_list if not self.check_bracketed(x)]
return word_list