"""Class definitions for PronunciationProbabilityTrainer"""
import contextlib
import json
import logging
import os
import re
import shutil
import time
import typing

import tqdm
from sqlalchemy.orm import joinedload

from montreal_forced_aligner.acoustic_modeling.base import AcousticModelTrainingMixin
from montreal_forced_aligner.alignment.multiprocessing import (
    GeneratePronunciationsArguments,
    GeneratePronunciationsFunction,
)
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.db import CorpusWorkflow, Dictionary, Pronunciation, Utterance, Word
from montreal_forced_aligner.g2p.trainer import PyniniTrainerMixin
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import parse_dictionary_file, run_kaldi_function
# Names exported by ``from <this module> import *``: only the trainer class.
__all__ = ["PronunciationProbabilityTrainer"]
# Logger named "mfa"; used below for progress and timing messages.
logger = logging.getLogger("mfa")
# NOTE: a "[docs]" source-link marker left over from the rendered HTML documentation was removed here.
class PronunciationProbabilityTrainer(AcousticModelTrainingMixin, PyniniTrainerMixin):
    """
    Class for training pronunciation probabilities based off of alignment pronunciations

    Parameters
    ----------
    previous_trainer: AcousticModelTrainingMixin, optional
        Previous trainer in the training configuration; its ``subset`` and
        exported model path are reused by this trainer
    silence_probabilities: bool
        Flag for whether to save silence probabilities
    train_g2p: bool
        Flag for training a G2P lexicon from the alignments (see
        :meth:`train_g2p_lexicon`) instead of estimating pronunciation
        probabilities
    use_phonetisaurus: bool
        Stored on the instance; not read anywhere else in this class
    num_iterations: int
        Forwarded to the parent initializer
    model_size: int
        Forwarded to the parent initializer
    """

    # Attribute names shared by the saved silence-info JSON, the Dictionary
    # rows and the worker; used when reloading saved probabilities.
    _SILENCE_ATTRIBUTES = (
        "silence_probability",
        "initial_silence_probability",
        "final_silence_correction",
        "final_non_silence_correction",
    )

    # Suffixes of intermediate G2P training files swept from the working
    # directory after each model is generated.
    _G2P_INTERMEDIATE_SUFFIXES = (".fst", ".like", ".far", ".enc")

    def __init__(
        self,
        previous_trainer: typing.Optional[AcousticModelTrainingMixin] = None,
        silence_probabilities: bool = True,
        train_g2p: bool = False,
        use_phonetisaurus: bool = False,
        num_iterations: int = 10,
        model_size: int = 100000,
        **kwargs,
    ):
        self.previous_trainer = previous_trainer
        self.silence_probabilities = silence_probabilities
        self.train_g2p = train_g2p
        self.use_phonetisaurus = use_phonetisaurus
        super().__init__(num_iterations=num_iterations, model_size=model_size, **kwargs)
        # ``previous_trainer`` defaults to None, so only inherit its subset
        # when one was actually supplied; otherwise whatever the parent
        # initializer configured is left untouched.
        if self.previous_trainer is not None:
            self.subset = self.previous_trainer.subset
        self.pronunciations_complete = False

    @property
    def train_type(self) -> str:
        """Training type"""
        return "pronunciation_probabilities"

    def compute_calculated_properties(self) -> None:
        """Compute calculated properties (nothing to compute for this trainer)"""

    def _trainer_initialization(self) -> None:
        """Initialize trainer (no initialization is performed for this trainer)"""

    @property
    def exported_model_path(self) -> str:
        """Path to exported acoustic model, taken from the previous trainer"""
        return self.previous_trainer.exported_model_path

    @property
    def model_path(self) -> str:
        """Current acoustic model path (``final.mdl`` in the working directory)"""
        return os.path.join(self.working_directory, "final.mdl")

    @property
    def alignment_model_path(self) -> str:
        """Alignment model path: the ``.alimdl`` file if it exists, else the model path"""
        path = self.model_path.replace(".mdl", ".alimdl")
        if os.path.exists(path):
            return path
        return self.model_path

    @property
    def phone_symbol_table_path(self) -> str:
        """Worker's phone symbol table"""
        return self.worker.phone_symbol_table_path

    @property
    def grapheme_symbol_table_path(self) -> str:
        """Worker's grapheme symbol table"""
        return self.worker.grapheme_symbol_table_path

    @property
    def input_path(self) -> str:
        """Path to the temporary input (text) training file for the current data source"""
        return os.path.join(self.working_directory, f"input_{self._data_source}.txt")

    @property
    def output_path(self) -> str:
        """Path to the temporary output (phone) training file for the current data source"""
        return os.path.join(self.working_directory, f"output_{self._data_source}.txt")

    @property
    def output_alignment_path(self) -> str:
        """Path to the temporary output training file that keeps the ``#1``/``#2`` symbols"""
        return os.path.join(self.working_directory, f"output_{self._data_source}_alignment.txt")

    def generate_pronunciations_arguments(self) -> typing.List[GeneratePronunciationsArguments]:
        """
        Generate Job arguments for :func:`~montreal_forced_aligner.alignment.multiprocessing.GeneratePronunciationsFunction`

        Returns
        -------
        list[:class:`~montreal_forced_aligner.alignment.multiprocessing.GeneratePronunciationsArguments`]
            Arguments for processing
        """
        return [
            GeneratePronunciationsArguments(
                j.id,
                getattr(self, "db_string", ""),
                os.path.join(self.working_log_directory, f"generate_pronunciations.{j.id}.log"),
                self.model_path,
                True,
            )
            for j in self.jobs
        ]

    def align_g2p(self, output_path=None) -> None:
        """
        Runs the entire alignment regimen.

        Parameters
        ----------
        output_path : str, optional
            Forwarded to the lexicon covering step
        """
        self._lexicon_covering(output_path=output_path)
        self._alignments()
        self._encode()

    def _write_g2p_training_data(
        self,
        arguments: typing.List[GeneratePronunciationsArguments],
        working_dir: str,
        texts: typing.Dict[int, str],
    ) -> None:
        """Write per-dictionary input/output/alignment training files from generated pronunciations"""
        base_names = self.worker.dictionary_base_names
        # De-duplicate while keeping order so each file is opened exactly once.
        dictionary_ids = list(dict.fromkeys(self.worker.dictionary_lookup.values()))
        whitespace = re.compile(r"\s+")

        # The ExitStack guarantees every opened file is closed, even when
        # pronunciation generation raises part-way through.
        with contextlib.ExitStack() as stack:

            def open_for_each_dictionary(prefix: str, suffix: str) -> dict:
                return {
                    dict_id: stack.enter_context(
                        open(
                            os.path.join(working_dir, f"{prefix}_{base_names[dict_id]}{suffix}"),
                            "w",
                            encoding="utf8",
                            newline="",
                        )
                    )
                    for dict_id in dictionary_ids
                }

            input_files = open_for_each_dictionary("input", ".txt")
            output_files = open_for_each_dictionary("output", ".txt")
            output_alignment_files = open_for_each_dictionary("output", "_alignment.txt")

            with tqdm.tqdm(
                total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet
            ) as pbar:
                for dict_id, utt_id, phones in run_kaldi_function(
                    GeneratePronunciationsFunction, arguments, pbar.update
                ):
                    text = texts.get(utt_id)
                    if not text:
                        # Utterance was filtered out or has no character text.
                        continue
                    print(phones, file=output_alignment_files[dict_id])
                    # Plain output drops the "#1"/"#2" symbols and collapses
                    # the whitespace they leave behind.
                    cleaned = phones.replace("#1", "").replace("#2", "")
                    print(whitespace.sub(" ", cleaned).strip(), file=output_files[dict_id])
                    print(text, file=input_files[dict_id])

    def _remove_g2p_intermediates(self, *paths: str) -> None:
        """Delete the given files, then sweep intermediate G2P files from the working directory"""
        for path in paths:
            os.remove(path)
        for file_name in os.listdir(self.working_directory):
            if file_name.endswith(self._G2P_INTERMEDIATE_SUFFIXES):
                os.remove(os.path.join(self.working_directory, file_name))

    def train_g2p_lexicon(self) -> None:
        """Generate a G2P lexicon based on aligned transcripts"""
        arguments = self.generate_pronunciations_arguments()
        working_dir = super().working_directory
        with self.worker.session() as session:
            query = session.query(Utterance.id, Utterance.normalized_character_text)
            query = query.filter(Utterance.ignored == False)  # noqa
            if self.subset:
                query = query.filter_by(in_subset=True)
            texts = {utt_id: text for utt_id, text in query}

            self._write_g2p_training_data(arguments, working_dir, texts)
            self.pronunciations_complete = True

            os.makedirs(self.working_log_directory, exist_ok=True)
            dictionaries = session.query(Dictionary)
            shutil.copyfile(
                self.phone_symbol_table_path, os.path.join(self.working_directory, "phones.txt")
            )
            shutil.copyfile(
                self.grapheme_symbol_table_path,
                os.path.join(self.working_directory, "graphemes.txt"),
            )
            self.input_token_type = self.grapheme_symbol_table_path
            self.output_token_type = self.phone_symbol_table_path
            keep_intermediates = GLOBAL_CONFIG.current_profile.debug

            for d in dictionaries:
                logger.info(f"Training G2P for {d.name}...")
                # Selects which input_*/output_* files the path properties resolve to.
                self._data_source = self.worker.dictionary_base_names[d.id]

                # First pass: model trained on the plain output file; it
                # replaces the dictionary's lexicon FST.
                begin = time.time()
                if os.path.exists(self.far_path) and os.path.exists(self.encoder_path):
                    logger.info("Alignment already done, skipping!")
                else:
                    self.align_g2p()
                    logger.debug(
                        f"Aligning utterances for {d.name} took {time.time() - begin:.3f} seconds"
                    )
                begin = time.time()
                self.generate_model()
                logger.debug(
                    f"Generating model for {d.name} took {time.time() - begin:.3f} seconds"
                )
                os.rename(d.lexicon_fst_path, d.lexicon_fst_path + ".backup")
                os.rename(self.fst_path, d.lexicon_fst_path)
                if not keep_intermediates:
                    self._remove_g2p_intermediates(
                        self.output_path, self.input_far_path, self.output_far_path
                    )

                # Second pass: model trained on the alignment output file; it
                # replaces the dictionary's align lexicon.
                begin = time.time()
                self.align_g2p(self.output_alignment_path)
                logger.debug(
                    f"Aligning utterances for {d.name} took {time.time() - begin:.3f} seconds"
                )
                begin = time.time()
                self.generate_model()
                logger.debug(
                    f"Generating model for {d.name} took {time.time() - begin:.3f} seconds"
                )
                os.rename(d.align_lexicon_path, d.align_lexicon_path + ".backup")
                os.rename(self.fst_path, d.align_lexicon_path)
                if not keep_intermediates:
                    self._remove_g2p_intermediates(
                        self.output_alignment_path,
                        self.input_path,
                        self.input_far_path,
                        self.output_far_path,
                    )
                d.use_g2p = True
            session.commit()
        self.worker.use_g2p = True

    def export_model(self, output_model_path: str) -> None:
        """
        Export an acoustic model to the specified path

        Parameters
        ----------
        output_model_path : str
            Path to save acoustic model
        """
        # Called on the acoustic mixin explicitly rather than through super().
        AcousticModelTrainingMixin.export_model(self, output_model_path)

    def setup(self):
        """Copy alignment archives and model files from the previous aligner into the current workflow directory"""
        wf = self.worker.current_workflow
        previous_directory = self.previous_aligner.working_directory
        target_directory = wf.working_directory
        for j in self.jobs:
            for p in j.construct_path_dictionary(previous_directory, "ali", "ark").values():
                shutil.copy(p, p.replace(previous_directory, target_directory))
        # Model files are optional; only those present are copied.
        for file_name in ["final.mdl", "final.alimdl", "final.occs", "lda.mat"]:
            p = os.path.join(previous_directory, file_name)
            if os.path.exists(p):
                shutil.copy(p, p.replace(previous_directory, target_directory))

    def _restore_g2p_lexicons(self) -> None:
        """Re-install previously trained G2P FSTs as each dictionary's lexicon FST"""
        self.pronunciations_complete = True
        with self.worker.session() as session:
            for d in session.query(Dictionary).all():
                fst_path = os.path.join(
                    self.working_directory,
                    f"{self.worker.dictionary_base_names[d.id]}.fst",
                )
                os.rename(d.lexicon_fst_path, d.lexicon_fst_path + ".backup")
                shutil.copy(fst_path, d.lexicon_fst_path)
                d.use_g2p = True
            session.commit()
        self.worker.use_g2p = True

    def _load_saved_probabilities(self) -> None:
        """Reload saved pronunciation and silence probabilities into the database and the worker"""
        totals = dict.fromkeys(self._SILENCE_ATTRIBUTES, 0)
        with self.worker.session() as session:
            dictionaries = session.query(Dictionary).all()
            for d in dictionaries:
                pronunciations = (
                    session.query(Pronunciation)
                    .join(Pronunciation.word)
                    .options(joinedload(Pronunciation.word, innerjoin=True))
                    .filter(Word.dictionary_id == d.id)
                )
                cache = {(x.word.word, x.pronunciation): x for x in pronunciations}
                new_dictionary_path = os.path.join(self.working_directory, f"{d.id}.dict")
                for (
                    word,
                    pron,
                    prob,
                    silence_after_prob,
                    silence_before_correct,
                    non_silence_before_correct,
                ) in parse_dictionary_file(new_dictionary_path):
                    p = cache.get((word, " ".join(pron)))
                    if p is None:
                        # Entry in the saved file has no matching database row.
                        continue
                    p.probability = prob
                    p.silence_after_probability = silence_after_prob
                    p.silence_before_correction = silence_before_correct
                    p.non_silence_before_correction = non_silence_before_correct
                silence_info_path = os.path.join(
                    self.working_directory, f"{d.id}_silence_info.json"
                )
                with mfa_open(silence_info_path, "r") as f:
                    data = json.load(f)
                if self.silence_probabilities:
                    # The totals are only consumed when silence probabilities
                    # are enabled, so they are only accumulated in that case.
                    for attribute in self._SILENCE_ATTRIBUTES:
                        setattr(d, attribute, data[attribute])
                        totals[attribute] += getattr(d, attribute)
            # Skip averaging when there are no dictionaries to avoid dividing by zero.
            if self.silence_probabilities and dictionaries:
                for attribute in self._SILENCE_ATTRIBUTES:
                    setattr(self.worker, attribute, totals[attribute] / len(dictionaries))
            session.commit()
        self.worker.write_lexicon_information()

    def train_pronunciation_probabilities(self) -> None:
        """
        Train pronunciation probabilities based on previous alignment
        """
        wf = self.worker.current_workflow
        os.makedirs(os.path.join(wf.working_directory, "log"), exist_ok=True)
        if wf.done:
            logger.info(
                "Pronunciation probability estimation already done, loading saved probabilities..."
            )
            self.training_complete = True
            if self.train_g2p:
                self._restore_g2p_lexicons()
            else:
                self._load_saved_probabilities()
            return

        self.setup()
        if self.train_g2p:
            self.train_g2p_lexicon()
        else:
            os.makedirs(self.working_log_directory, exist_ok=True)
            self.worker.compute_pronunciation_probabilities()
            self.worker.write_lexicon_information()
            with self.worker.session() as session:
                for d in session.query(Dictionary):
                    dict_path = os.path.join(self.working_directory, f"{d.id}.dict")
                    self.worker.export_trained_rules(self.working_directory)
                    self.worker.export_lexicon(
                        d.id,
                        dict_path,
                        probability=True,
                    )
                    silence_info_path = os.path.join(
                        self.working_directory, f"{d.id}_silence_info.json"
                    )
                    with mfa_open(silence_info_path, "w") as f:
                        json.dump(d.silence_probability_info, f)
        # Mark the workflow as done so later runs take the reload branch above.
        with self.session() as session:
            session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update({"done": True})
            session.commit()

    def train_iteration(self) -> None:
        """Training iteration (no per-iteration work for this trainer)"""