"""Classes for tokenizers"""
import csv
import functools
import logging
import os
import queue
import threading
import time
import typing
from multiprocessing.pool import ThreadPool
from pathlib import Path
from queue import Queue
import pynini
import pywrapfst
from praatio import textgrid
from pynini import Fst
from pynini.lib import rewrite
from pywrapfst import SymbolTable
from sqlalchemy.orm import joinedload, selectinload
from tqdm.rich import tqdm
from montreal_forced_aligner import config
from montreal_forced_aligner.abc import KaldiFunction, TopLevelMfaWorker
from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpusMixin
from montreal_forced_aligner.data import MfaArguments, TextgridFormats
from montreal_forced_aligner.db import File, Utterance, bulk_update
from montreal_forced_aligner.dictionary.mixins import DictionaryMixin
from montreal_forced_aligner.exceptions import PyniniGenerationError
from montreal_forced_aligner.g2p.generator import PhonetisaurusRewriter, Rewriter, RewriterWorker
from montreal_forced_aligner.helper import edit_distance, mfa_open
from montreal_forced_aligner.models import TokenizerModel
from montreal_forced_aligner.textgrid import construct_output_path
from montreal_forced_aligner.utils import run_kaldi_function
if typing.TYPE_CHECKING:
from dataclasses import dataclass
else:
from dataclassy import dataclass
__all__ = [
"TokenizerRewriter",
"TokenizerArguments",
"TokenizerFunction",
"TokenizerValidator",
"CorpusTokenizer",
]
logger = logging.getLogger("mfa")
class TokenizerRewriter(Rewriter):
"""
Helper object for rewriting
Parameters
----------
fst: pynini.Fst
Tokenizer FST model
grapheme_symbols: pynini.SymbolTable
Grapheme symbol table
"""
def __init__(
self,
fst: Fst,
grapheme_symbols: SymbolTable,
):
self.grapheme_symbols = grapheme_symbols
self.rewrite = functools.partial(
rewrite.top_rewrite,
rule=fst,
input_token_type=grapheme_symbols,
output_token_type=grapheme_symbols,
)
def __call__(self, i: str) -> str: # pragma: no cover
"""Call the rewrite function"""
i = i.replace(" ", "")
original = list(i)
unks = []
normalized = []
for c in original:
if self.grapheme_symbols.member(c):
normalized.append(c)
else:
unks.append(c)
normalized.append("<unk>")
hypothesis = self.rewrite(" ".join(normalized)).split()
unk_index = 0
for idx, w in enumerate(hypothesis):
if w == "<unk>":
hypothesis[idx] = unks[unk_index]
unk_index += 1
elif w == "<space>":
hypothesis[idx] = " "
return "".join(hypothesis)
class TokenizerPhonetisaurusRewriter(PhonetisaurusRewriter):
"""
Helper object for rewriting text with a Phonetisaurus-style tokenizer FST
Parameters
----------
fst: pynini.Fst
Tokenizer FST model
input_token_type: pynini.SymbolTable
Symbol table for input grapheme sequences
output_token_type: pynini.SymbolTable
Symbol table for output tokens
input_order: int
Maximum number of graphemes to consider as a single segment, defaults to 2
seq_sep: str
Separator to use between grapheme symbols, defaults to "|"
"""
def __init__(
self,
fst: Fst,
input_token_type: SymbolTable,
output_token_type: SymbolTable,
input_order: int = 2,
seq_sep: str = "|",
):
self.fst = fst
self.seq_sep = seq_sep
self.input_token_type = input_token_type
self.output_token_type = output_token_type
self.input_order = input_order
self.rewrite = functools.partial(
rewrite.top_rewrite,
rule=fst,
input_token_type=None,
output_token_type=output_token_type,
)
def __call__(self, graphemes: str) -> str: # pragma: no cover
"""Call the rewrite function"""
graphemes = graphemes.replace(" ", "")
original = list(graphemes)
unks = []
normalized = []
for c in original:
if self.output_token_type.member(c):
normalized.append(c)
else:
unks.append(c)
normalized.append("<unk>")
fst = pynini.Fst()
one = pynini.Weight.one(fst.weight_type())
max_state = 0
for i in range(len(normalized)):
start_state = fst.add_state()
for j in range(1, self.input_order + 1):
if i + j <= len(normalized):
substring = self.seq_sep.join(normalized[i : i + j])
ilabel = self.input_token_type.find(substring)
if ilabel != pynini.NO_LABEL:
fst.add_arc(start_state, pynini.Arc(ilabel, ilabel, one, i + j))
if i + j >= max_state:
max_state = i + j
for _ in range(fst.num_states(), max_state + 1):
fst.add_state()
fst.set_start(0)
fst.set_final(len(normalized), one)
fst.set_input_symbols(self.input_token_type)
fst.set_output_symbols(self.input_token_type)
hypothesis = self.rewrite(fst).split()
unk_index = 0
output = []
for i, w in enumerate(hypothesis):
if w == "<unk>":
output.append(unks[unk_index])
unk_index += 1
elif w == "<space>":
if i > 0 and hypothesis[i - 1] == " ":
continue
output.append(" ")
else:
output.append(w)
return "".join(output).strip()
@dataclass
class TokenizerArguments(MfaArguments):
"""Arguments for TokenizerFunction, adding the rewriter to the base MfaArguments"""
rewriter: Rewriter
class TokenizerFunction(KaldiFunction):
"""Job function for tokenizing utterances with a trained tokenizer model"""
def __init__(self, args: TokenizerArguments):
super().__init__(args)
self.rewriter = args.rewriter
def _run(self) -> None:
"""Run the function"""
with self.session() as session:
utterances = session.query(Utterance.id, Utterance.normalized_text).filter(
Utterance.job_id == self.job_name
)
for u_id, text in utterances:
tokenized_text = self.rewriter(text)
self.callback((u_id, tokenized_text))
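# TokenizerFunction is intended to be driven by run_kaldi_function (see
# CorpusTokenizer.tokenize_utterances below): each job queries its own utterances
# and emits (utterance_id, tokenized_text) pairs through the callback.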
class CorpusTokenizer(AcousticCorpusMixin, TopLevelMfaWorker, DictionaryMixin):
"""
Top-level worker for tokenizing utterances in a corpus with a trained Pynini tokenizer model
"""
model_class = TokenizerModel
def __init__(self, tokenizer_model_path: Path = None, **kwargs):
super().__init__(**kwargs)
self.tokenizer_model = TokenizerModel(
tokenizer_model_path, root_directory=getattr(self, "workflow_directory", None)
)
def setup(self) -> None:
"""Set up the pronunciation generator"""
if self.initialized:
return
self._load_corpus()
self.initialize_jobs()
super().setup()
self._create_dummy_dictionary()
self.normalize_text()
self.fst = pynini.Fst.read(self.tokenizer_model.fst_path)
if self.tokenizer_model.meta["architecture"] == "phonetisaurus":
self.output_token_type = pywrapfst.SymbolTable.read_text(
self.tokenizer_model.output_sym_path
)
self.input_token_type = pywrapfst.SymbolTable.read_text(
self.tokenizer_model.input_sym_path
)
self.rewriter = TokenizerPhonetisaurusRewriter(
self.fst,
self.input_token_type,
self.output_token_type,
input_order=self.tokenizer_model.meta["input_order"],
)
else:
self.grapheme_symbols = pywrapfst.SymbolTable.read_text(self.tokenizer_model.sym_path)
self.rewriter = TokenizerRewriter(
self.fst,
self.grapheme_symbols,
)
self.initialized = True
def export_files(self, output_directory: Path) -> None:
"""Export transcriptions"""
with self.session() as session:
files = session.query(File).options(
selectinload(File.utterances),
selectinload(File.speakers),
joinedload(File.sound_file),
)
for file in files:
utterance_count = len(file.utterances)
if file.sound_file is not None:
duration = file.sound_file.duration
else:
duration = max((u.end for u in file.utterances), default=0)
if utterance_count == 0:
logger.debug(f"Could not find any utterances for {file.name}")
continue
elif (
utterance_count == 1
and file.utterances[0].begin == 0
and file.utterances[0].end == duration
):
output_format = "lab"
else:
output_format = TextgridFormats.SHORT_TEXTGRID
output_path = construct_output_path(
file.name,
file.relative_path,
output_directory,
output_format=output_format,
)
data = file.construct_transcription_tiers(original_text=True)
if output_format == "lab":
for intervals in data.values():
with mfa_open(output_path, "w") as f:
f.write(intervals["text"][0].label)
else:
tg = textgrid.Textgrid()
tg.minTimestamp = 0
tg.maxTimestamp = round(duration, 5)
for speaker in file.speakers:
speaker = speaker.name
intervals = data[speaker]["text"]
tier = textgrid.IntervalTier(
speaker,
[x.to_tg_interval() for x in intervals],
minT=0,
maxT=round(duration, 5),
)
tg.addTier(tier)
tg.save(output_path, includeBlankSpaces=True, format=output_format)
def tokenize_arguments(self) -> typing.List[TokenizerArguments]:
"""Construct TokenizerArguments for each job in the corpus"""
return [
TokenizerArguments(
j.id,
getattr(self, "session" if config.USE_THREADING else "db_string", ""),
None,
self.rewriter,
)
for j in self.jobs
]
def tokenize_utterances(self) -> None:
"""
Tokenize all utterances in the corpus and save the tokenized text to the database
"""
begin = time.time()
if not self.initialized:
self.setup()
logger.info("Tokenizing utterances...")
args = self.tokenize_arguments()
update_mapping = []
for utt_id, tokenized in run_kaldi_function(
TokenizerFunction, args, total_count=self.num_utterances
):
update_mapping.append({"id": utt_id, "text": tokenized})
with self.session() as session:
bulk_update(session, Utterance, update_mapping)
session.commit()
logger.debug(f"Tokenizing utterances took {time.time() - begin:.3f} seconds")
class TokenizerValidator(CorpusTokenizer):
"""Top-level worker for evaluating a tokenizer model against gold tokenizations"""
def __init__(self, utterances_to_tokenize: typing.List[str] = None, **kwargs):
super().__init__(**kwargs)
if utterances_to_tokenize is None:
utterances_to_tokenize = []
self.utterances_to_tokenize = utterances_to_tokenize
self.uer = None
self.cer = None
def setup(self):
TopLevelMfaWorker.setup(self)
if self.initialized:
return
self._current_workflow = "validation"
os.makedirs(self.working_log_directory, exist_ok=True)
self.fst = pynini.Fst.read(self.tokenizer_model.fst_path)
if self.tokenizer_model.meta["architecture"] == "phonetisaurus":
self.output_token_type = pywrapfst.SymbolTable.read_text(
self.tokenizer_model.output_sym_path
)
self.input_token_type = pywrapfst.SymbolTable.read_text(
self.tokenizer_model.input_sym_path
)
self.rewriter = TokenizerPhonetisaurusRewriter(
self.fst,
self.input_token_type,
self.output_token_type,
input_order=self.tokenizer_model.meta["input_order"],
)
else:
self.grapheme_symbols = pywrapfst.SymbolTable.read_text(self.tokenizer_model.sym_path)
self.rewriter = TokenizerRewriter(
self.fst,
self.grapheme_symbols,
)
self.initialized = True
def tokenize_utterances(self) -> typing.Dict[str, str]:
"""
Tokenize utterances
Returns
-------
dict[str, str]
Mapping of each input utterance to its tokenized form
"""
num_utterances = len(self.utterances_to_tokenize)
begin = time.time()
if not self.initialized:
self.setup()
logger.info("Tokenizing utterances...")
to_return = {}
if num_utterances < 30 or config.NUM_JOBS == 1:
with tqdm(total=num_utterances, disable=config.QUIET) as pbar:
for utterance in self.utterances_to_tokenize:
pbar.update(1)
result = self.rewriter(utterance)
to_return[utterance] = result
else:
stopped = threading.Event()
job_queue = Queue()
for utterance in self.utterances_to_tokenize:
job_queue.put(utterance)
error_dict = {}
return_queue = Queue()
procs = []
for _ in range(config.NUM_JOBS):
p = RewriterWorker(
job_queue,
return_queue,
self.rewriter,
stopped,
)
procs.append(p)
p.start()
with tqdm(total=num_utterances, disable=config.QUIET) as pbar:
while True:
try:
utterance, result = return_queue.get(timeout=1)
if stopped.is_set():
continue
return_queue.task_done()
except queue.Empty:
for proc in procs:
if not proc.finished.is_set():
break
else:
break
continue
pbar.update(1)
if isinstance(result, Exception):
error_dict[utterance] = result
continue
to_return[utterance] = result
for p in procs:
p.join()
if error_dict:
raise PyniniGenerationError(error_dict)
logger.debug(f"Processed {num_utterances} in {time.time() - begin:.3f} seconds")
return to_return
@property
def data_source_identifier(self) -> str:
"""Dummy "validation" data source"""
return "validation"
@property
def data_directory(self) -> Path:
"""Data directory"""
return self.working_directory
@property
def evaluation_csv_path(self) -> Path:
"""Path to working directory's CSV file"""
return self.working_directory.joinpath("pronunciation_evaluation.csv")
def compute_validation_errors(
self,
gold_values: typing.Dict[str, str],
hypothesis_values: typing.Dict[str, str],
):
"""
Compute validation errors between gold and hypothesis tokenizations
Parameters
----------
gold_values: dict[str, str]
Gold tokenizations
hypothesis_values: dict[str, str]
Hypothesis tokenizations
"""
begin = time.time()
# Word-level measures.
correct = 0
incorrect = 0
# Label-level measures.
total_edits = 0
total_length = 0
# Since the edit distance algorithm is quadratic, parallelize the computation
# across a thread pool.
logger.debug(f"Processing results for {len(hypothesis_values)} hypotheses")
to_comp = []
indices = []
output = []
for word, gold in gold_values.items():
if word not in hypothesis_values:
incorrect += 1
gold_length = len(gold)
total_edits += gold_length
total_length += gold_length
output.append(
{
"Word": word,
"Gold tokenization": gold,
"Hypothesis tokenization": "",
"Accuracy": 0,
"Error rate": 1.0,
"Length": gold_length,
}
)
continue
hyp = hypothesis_values[word]
if hyp == gold:
correct += 1
total_length += len(hyp)
output.append(
{
"Word": word,
"Gold tokenization": gold,
"Hypothesis tokenization": hyp,
"Accuracy": 1,
"Error rate": 0.0,
"Length": len(hyp),
}
)
else:
incorrect += 1
indices.append(word)
to_comp.append((gold, hyp))  # Compute edit distance for this pair below
with ThreadPool(config.NUM_JOBS) as pool:
gen = pool.starmap(edit_distance, to_comp)
for i, edits in enumerate(gen):
word = indices[i]
gold = gold_values[word]
length = len(gold)
hyp = hypothesis_values[word]
output.append(
{
"Word": word,
"Gold tokenization": gold,
"Hypothesis tokenization": hyp,
"Accuracy": 1,
"Error rate": edits / length,
"Length": length,
}
)
total_edits += edits
total_length += length
with mfa_open(self.evaluation_csv_path, "w") as f:
writer = csv.DictWriter(
f,
fieldnames=[
"Word",
"Gold tokenization",
"Hypothesis tokenization",
"Accuracy",
"Error rate",
"Length",
],
)
writer.writeheader()
for line in output:
writer.writerow(line)
self.uer = 100 * incorrect / (correct + incorrect)
self.cer = 100 * total_edits / total_length
logger.info(f"UER:\t{self.uer:.2f}")
logger.info(f"CER:\t{self.cer:.2f}")
logger.debug(
f"Computation of errors for {len(gold_values)} utterances took {time.time() - begin:.3f} seconds"
)
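# Worked example for the measures above (hypothetical numbers): with 100 gold
# utterances of which 80 are tokenized exactly, UER = 100 * 20 / 100 = 20.00;
# if the gold tokenizations total 2000 characters and the hypotheses require 60
# character edits overall, CER = 100 * 60 / 2000 = 3.00.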