"""
Transcription
=============
"""
from __future__ import annotations
import collections
import csv
import logging
import multiprocessing as mp
import os
import shutil
import subprocess
import time
import typing
from pathlib import Path
from queue import Empty
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import pywrapfst
from praatio import textgrid
from sqlalchemy.orm import joinedload, selectinload
from tqdm.rich import tqdm
from montreal_forced_aligner.abc import TopLevelMfaWorker
from montreal_forced_aligner.alignment.base import CorpusAligner
from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import (
ArpaNgramModel,
TextFileType,
TextgridFormats,
WorkflowType,
)
from montreal_forced_aligner.db import (
CorpusWorkflow,
Dictionary,
File,
Phone,
SoundFile,
Speaker,
Utterance,
bulk_update,
)
from montreal_forced_aligner.exceptions import KaldiProcessingError
from montreal_forced_aligner.helper import (
load_configuration,
mfa_open,
parse_old_features,
score_wer,
)
from montreal_forced_aligner.language_modeling.multiprocessing import (
TrainLmArguments,
TrainPhoneLmFunction,
TrainSpeakerLmArguments,
TrainSpeakerLmFunction,
)
from montreal_forced_aligner.models import AcousticModel, LanguageModel
from montreal_forced_aligner.textgrid import construct_output_path
from montreal_forced_aligner.transcription.multiprocessing import (
CarpaLmRescoreArguments,
CarpaLmRescoreFunction,
CreateHclgArguments,
CreateHclgFunction,
DecodeArguments,
DecodeFunction,
DecodePhoneArguments,
DecodePhoneFunction,
FinalFmllrArguments,
FinalFmllrFunction,
FmllrRescoreArguments,
FmllrRescoreFunction,
InitialFmllrArguments,
InitialFmllrFunction,
LatGenFmllrArguments,
LatGenFmllrFunction,
LmRescoreArguments,
LmRescoreFunction,
PerSpeakerDecodeArguments,
PerSpeakerDecodeFunction,
)
from montreal_forced_aligner.utils import (
KaldiProcessWorker,
Stopped,
log_kaldi_errors,
run_kaldi_function,
thirdparty_binary,
)
if TYPE_CHECKING:
from montreal_forced_aligner.abc import MetaDict
__all__ = ["Transcriber", "TranscriberMixin"]
logger = logging.getLogger("mfa")
[docs]
class TranscriberMixin(CorpusAligner):
"""Abstract class for MFA transcribers
Parameters
----------
transition_scale: float
Transition scale, defaults to 1.0
acoustic_scale: float
Acoustic scale, defaults to 0.083333
self_loop_scale: float
Self-loop scale, defaults to 0.1
beam: int
Size of the beam to use in decoding, defaults to 10
silence_weight: float
Weight on silence in fMLLR estimation
max_active: int
Max active for decoding
lattice_beam: int
Beam width for decoding lattices
first_beam: int
Beam for decoding in initial speaker-independent pass, only used if ``uses_speaker_adaptation`` is true
first_max_active: int
Max active for decoding in initial speaker-independent pass, only used if ``uses_speaker_adaptation`` is true
language_model_weight: float
Weight of language model
word_insertion_penalty: float
Penalty for inserting words
"""
def __init__(
    self,
    transition_scale: float = 1.0,
    acoustic_scale: float = 0.083333,
    self_loop_scale: float = 0.1,
    beam: int = 10,
    silence_weight: float = 0.01,
    first_beam: int = 10,
    first_max_active: int = 2000,
    language_model_weight: int = 10,
    word_insertion_penalty: float = 0.5,
    evaluation_mode: bool = False,
    **kwargs,
):
    """Record the decoding settings on the instance.

    Keyword arguments that are not named in the signature are forwarded
    unchanged to the parent initializer.
    """
    super().__init__(**kwargs)

    # Scales applied to the decoding graph and acoustic scores
    self.transition_scale = transition_scale
    self.acoustic_scale = acoustic_scale
    self.self_loop_scale = self_loop_scale

    # Search settings for the main and first decoding passes
    self.beam = beam
    self.first_beam = first_beam
    self.first_max_active = first_max_active

    # fMLLR estimation and language model weighting
    self.silence_weight = silence_weight
    self.language_model_weight = language_model_weight
    self.word_insertion_penalty = word_insertion_penalty

    # Mode flags; alignment mode always starts switched off
    self.evaluation_mode = evaluation_mode
    self.alignment_mode = False
[docs]
def train_speaker_lm_arguments(
    self,
) -> List[TrainSpeakerLmArguments]:
    """
    Generate Job arguments for :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainSpeakerLmFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainSpeakerLmArguments`]
        Arguments for processing, one entry per job in ``self.jobs``
    """
    # Every argument depends only on the job id and on corpus-wide settings, so no
    # database session is opened here. The previous implementation queried the
    # speakers of each job and built per-dictionary mappings that were never used.
    db_string = getattr(self, "db_string", "")
    return [
        TrainSpeakerLmArguments(
            j.id,
            db_string,
            self.working_log_directory.joinpath(f"train_lm.{j.id}.log"),
            self.model_path,
            self.order,
            self.method,
            self.target_num_ngrams,
            self.hclg_options,
        )
        for j in self.jobs
    ]
[docs]
def train_speaker_lms(self) -> None:
    """Train language models for each speaker based on their utterances

    Raises
    ------
    Exception
        The first error reported by a worker process, raised after all
        workers have been joined
    """
    begin = time.time()
    log_directory = self.model_log_directory
    os.makedirs(log_directory, exist_ok=True)
    logger.info("Compiling per speaker biased language models...")
    arguments = self.train_speaker_lm_arguments()
    with tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar:
        if GLOBAL_CONFIG.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(arguments):
                function = TrainSpeakerLmFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Queue is drained: stop polling once every worker reports finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                # Record errors before looking at the stop flag, as the other
                # multiprocessing loops in this class do; checking the flag first
                # would discard an error received while the flag is set.
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                pbar.update(1)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            logger.debug("Not using multiprocessing...")
            for args in arguments:
                function = TrainSpeakerLmFunction(args)
                for _ in function.run():
                    pbar.update(1)
    logger.debug(f"Compiling speaker language models took {time.time() - begin:.3f} seconds")
@property
def model_directory(self) -> Path:
    """Directory named ``models`` inside the transcriber's output directory"""
    base_directory = self.output_directory
    return base_directory.joinpath("models")
@property
def model_log_directory(self) -> Path:
    """Log directory (``log``) inside the transcriber's model directory"""
    parent_directory = self.model_directory
    return parent_directory.joinpath("log")
[docs]
def lm_rescore(self) -> None:
    """
    Rescore lattices with bigger language model

    See Also
    -------
    :class:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreFunction`
        Multiprocessing function
    :meth:`.TranscriberMixin.lm_rescore_arguments`
        Arguments for function

    Raises
    ------
    Exception
        The first error reported by a worker process, raised after all
        workers have been joined
    """
    logger.info("Rescoring lattices with medium G.fst...")
    if GLOBAL_CONFIG.use_mp:
        error_dict = {}
        return_queue = mp.Queue()
        stopped = Stopped()
        procs = []
        for i, args in enumerate(self.lm_rescore_arguments()):
            function = LmRescoreFunction(args)
            p = KaldiProcessWorker(i, return_queue, function, stopped)
            procs.append(p)
            p.start()
        with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Queue is drained: stop polling once every worker reports finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                succeeded, failed = result
                if failed:
                    logger.warning("Some lattices failed to be rescored")
                pbar.update(succeeded + failed)
        for p in procs:
            p.join()
        if error_dict:
            for v in error_dict.values():
                raise v
    else:
        # One progress bar for the whole run, sized in utterances like the
        # multiprocessing branch, because updates are utterance counts.
        with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            for args in self.lm_rescore_arguments():
                function = LmRescoreFunction(args)
                for succeeded, failed in function.run():
                    if failed:
                        logger.warning("Some lattices failed to be rescored")
                    pbar.update(succeeded + failed)
[docs]
def carpa_lm_rescore(self) -> None:
    """
    Rescore lattices with CARPA language model

    See Also
    -------
    :class:`~montreal_forced_aligner.transcription.multiprocessing.CarpaLmRescoreFunction`
        Multiprocessing function
    :meth:`.TranscriberMixin.carpa_lm_rescore_arguments`
        Arguments for function

    Raises
    ------
    Exception
        The first error reported by a worker process, raised after all
        workers have been joined
    """
    logger.info("Rescoring lattices with large G.carpa...")
    if GLOBAL_CONFIG.use_mp:
        error_dict = {}
        return_queue = mp.Queue()
        stopped = Stopped()
        procs = []
        for i, args in enumerate(self.carpa_lm_rescore_arguments()):
            function = CarpaLmRescoreFunction(args)
            p = KaldiProcessWorker(i, return_queue, function, stopped)
            procs.append(p)
            p.start()
        with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Queue is drained: stop polling once every worker reports finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                succeeded, failed = result
                if failed:
                    logger.warning("Some lattices failed to be rescored")
                pbar.update(succeeded + failed)
        for p in procs:
            p.join()
        if error_dict:
            for v in error_dict.values():
                raise v
    else:
        # A single bar sized in utterances covers every job; a bar per job with the
        # full utterance total could never reach completion.
        with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            for args in self.carpa_lm_rescore_arguments():
                function = CarpaLmRescoreFunction(args)
                for succeeded, failed in function.run():
                    if failed:
                        logger.warning("Some lattices failed to be rescored")
                    pbar.update(succeeded + failed)
[docs]
def train_phone_lm(self):
    """Train a phone-based language model (i.e., not using words).

    Worker processes write per-job n-gram counts into ``self.phones_dir``; the
    counts are merged, smoothed with ``ngrammake`` and pruned with ``ngramshrink``,
    and the resulting ARPA model is combined with the phone bigrams observed in
    the worker output to build ``phone_lm.fst``.

    Logs an error and returns when the current workflow has no alignments, and
    returns without doing anything when ``self.use_g2p`` is set.
    """
    if not self.has_alignments(self.current_workflow.id):
        logger.error("Cannot train phone LM without alignments")
        return
    if self.use_g2p:
        return
    logger.info("Beginning phone LM training...")
    logger.info("Collecting training data...")

    ngram_order = 4
    num_ngrams = 20000
    phone_lm_path = self.phones_dir.joinpath("phone_lm.fst")
    log_path = self.phones_dir.joinpath("phone_lm_training.log")

    unigram_phones = set()
    allowed_bigrams = collections.defaultdict(set)
    return_queue = mp.Queue()
    stopped = Stopped()
    error_dict = {}
    procs = []
    count_paths = []

    with self.session() as session, tqdm(
        total=self.num_current_utterances, disable=GLOBAL_CONFIG.quiet
    ) as pbar:
        boundary_path = self.phones_dir.joinpath("phone_boundaries.int")
        with mfa_open(boundary_path, "w") as boundary_file:
            for phone_row in session.query(Phone):
                boundary_file.write(f"{phone_row.mapping_id} singleton\n")

        # Start one counting worker per job
        for j in self.jobs:
            args = TrainLmArguments(
                j.id,
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"ngram_count.{j.id}.log"),
                self.phones_dir,
                self.phone_symbol_table_path,
                ngram_order,
                self.oov_word,
            )
            worker = KaldiProcessWorker(j.id, return_queue, TrainPhoneLmFunction(args), stopped)
            procs.append(worker)
            worker.start()
            count_paths.append(self.phones_dir.joinpath(f"{j.id}.cnts"))

        # Drain results until the queue is empty and every worker has finished
        while True:
            try:
                result = return_queue.get(timeout=1)
            except Empty:
                if all(proc.finished.stop_check() for proc in procs):
                    break
                continue
            if isinstance(result, Exception):
                error_dict[getattr(result, "job_name", 0)] = result
                continue
            if stopped.stop_check():
                continue
            _, phone_string = result
            sequence = phone_string.split()
            unigram_phones.update(sequence)
            # Record each adjacent pair, including sentence boundary markers
            padded = ["<s>"] + sequence + ["</s>"]
            for previous_phone, next_phone in zip(padded, padded[1:]):
                allowed_bigrams[previous_phone].add(next_phone)
            pbar.update(1)

    for worker in procs:
        worker.join()
    if error_dict:
        for error in error_dict.values():
            raise error

    logger.info("Training model...")
    with mfa_open(log_path, "w") as log_file:
        merged_file = self.phones_dir.joinpath("merged.cnts")
        if len(count_paths) > 1:
            merge_command = [
                thirdparty_binary("ngrammerge"),
                f"--ofile={merged_file}",
                *count_paths,
            ]
            ngrammerge_proc = subprocess.Popen(merge_command, stderr=log_file, env=os.environ)
            ngrammerge_proc.communicate()
        else:
            # A single counts file needs no merging
            os.rename(count_paths[0], merged_file)

        # ngrammake | ngramshrink | ngramprint
        ngrammake_proc = subprocess.Popen(
            [thirdparty_binary("ngrammake"), "--v=2", "--method=kneser_ney", merged_file],
            stderr=log_file,
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        ngramshrink_proc = subprocess.Popen(
            [
                thirdparty_binary("ngramshrink"),
                "--v=2",
                "--method=relative_entropy",
                f"--target_number_of_ngrams={num_ngrams}",
            ],
            stderr=log_file,
            stdin=ngrammake_proc.stdout,
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        print_proc = subprocess.Popen(
            [
                thirdparty_binary("ngramprint"),
                "--ARPA",
                f"--symbols={self.phone_symbol_table_path}",
            ],
            stdin=ngramshrink_proc.stdout,
            stderr=log_file,
            encoding="utf8",
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        model = ArpaNgramModel.read(print_proc.stdout)

        phone_symbols = pywrapfst.SymbolTable()
        for _, phone in sorted(self.reversed_phone_mapping.items()):
            phone_symbols.add_symbol(phone)
        log_file.write("Done training initial ngram model\n")
        log_file.flush()

        bigram_fst = model.construct_bigram_fst("#1", allowed_bigrams, phone_symbols)
        bigram_fst.write(self.phones_dir.joinpath("bigram.fst"))
        bigram_fst.project("output")

        # fstpushspecial | fstminimizeencoded | fstrmsymbols -> phone_lm.fst
        push_special_proc = subprocess.Popen(
            [thirdparty_binary("fstpushspecial")],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=log_file,
            env=os.environ,
        )
        minimize_proc = subprocess.Popen(
            [thirdparty_binary("fstminimizeencoded")],
            stdin=push_special_proc.stdout,
            stdout=subprocess.PIPE,
            stderr=log_file,
            env=os.environ,
        )
        rm_syms_proc = subprocess.Popen(
            [
                thirdparty_binary("fstrmsymbols"),
                "--remove-from-output=true",
                self.disambiguation_symbols_int_path,
                "-",
                phone_lm_path,
            ],
            stdin=minimize_proc.stdout,
            stderr=log_file,
            env=os.environ,
        )
        # Feed the serialized FST into the head of the pipeline, then wait on its tail
        pipeline_input = push_special_proc.stdin
        pipeline_input.write(bigram_fst.write_to_string())
        pipeline_input.flush()
        pipeline_input.close()
        rm_syms_proc.communicate()
[docs]
def setup_phone_lm(self) -> None:
    """Setup phone language model for phone-based transcription

    Trains the phone language model, then composes it into ``HCLG_phone.fst``
    in the working directory, logging to ``hclg.log``.
    """
    from montreal_forced_aligner.transcription.multiprocessing import compose_clg, compose_hclg

    self.train_phone_lm()
    hclg_log_path = self.working_log_directory.joinpath("hclg.log")
    with mfa_open(hclg_log_path, "w") as log_file:
        context_width = self.hclg_options["context_width"]
        central_pos = self.hclg_options["central_pos"]
        context_suffix = f"{context_width}_{central_pos}"

        # Intermediate and final graph locations inside the working directory
        clg_path = os.path.join(self.working_directory, f"CLG_{context_suffix}.fst")
        hclga_path = self.working_directory.joinpath("HCLGa.fst")
        hclg_path = self.working_directory.joinpath("HCLG_phone.fst")
        ilabels_temp = os.path.join(self.working_directory, f"ilabels_{context_suffix}")
        out_disambig = os.path.join(
            self.working_directory, f"disambig_ilabels_{context_suffix}.int"
        )

        compose_clg(
            self.disambiguation_symbols_int_path,
            out_disambig,
            context_width,
            central_pos,
            ilabels_temp,
            self.phones_dir.joinpath("phone_lm.fst"),
            clg_path,
            log_file,
        )

        log_file.write("Generating HCLGa.fst...")
        compose_hclg(
            self.model_path,
            ilabels_temp,
            self.hclg_options["transition_scale"],
            clg_path,
            hclga_path,
            log_file,
        )

        log_file.write("Generating HCLG.fst...")
        # add-self-loops | fstconvert -> HCLG_phone.fst
        self_loop_command = [
            thirdparty_binary("add-self-loops"),
            f"--self-loop-scale={self.hclg_options['self_loop_scale']}",
            "--reorder=true",
            self.model_path,
            hclga_path,
        ]
        self_loop_proc = subprocess.Popen(
            self_loop_command,
            stderr=log_file,
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        convert_command = [
            thirdparty_binary("fstconvert"),
            "--v=100",
            "--fst_type=const",
            "-",
            hclg_path,
        ]
        convert_proc = subprocess.Popen(
            convert_command,
            stdin=self_loop_proc.stdout,
            stderr=log_file,
            env=os.environ,
        )
        convert_proc.communicate()
def transcribe(self, workflow_type: WorkflowType = WorkflowType.transcription):
    """
    Create a new workflow of the requested type and transcribe the corpus

    Parameters
    ----------
    workflow_type: :class:`~montreal_forced_aligner.data.WorkflowType`
        Kind of transcription to run, defaults to regular transcription
    """
    self.initialize_database()
    # Captured before the new workflow is created so files can be copied from it
    # NOTE(review): presumably working_directory changes with the workflow - confirm
    previous_working_directory = self.working_directory
    self.create_new_current_workflow(workflow_type)
    if workflow_type is WorkflowType.phone_transcription:
        self.setup_phone_lm()
        for a in self.calc_fmllr_arguments():
            for p in a.trans_paths.values():
                shutil.copyfile(previous_working_directory.joinpath(p.name), p)
    elif workflow_type is WorkflowType.per_speaker_transcription:
        for a in self.calc_fmllr_arguments():
            for p in a.trans_paths.values():
                # Copy only transforms that exist in the previous directory; the
                # check used to test the destination path instead of the source.
                previous_path = previous_working_directory.joinpath(p.name)
                if os.path.exists(previous_path):
                    shutil.copyfile(previous_path, p)
    self.acoustic_model.export_model(self.working_directory)
    self.transcribe_utterances()
[docs]
def transcribe_utterances(self) -> None:
    """
    Transcribe the corpus

    Decodes the corpus and, for regular transcription workflows, runs the fMLLR
    and language model rescoring passes unless all CARPA-rescored lattices already
    exist. The workflow is marked done on success and dirty on any exception.

    See Also
    --------
    :func:`~montreal_forced_aligner.transcription.multiprocessing.DecodeFunction`
        Multiprocessing helper function for each job
    :func:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreFunction`
        Multiprocessing helper function for each job
    :func:`~montreal_forced_aligner.transcription.multiprocessing.CarpaLmRescoreFunction`
        Multiprocessing helper function for each job

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
        If there were any errors in running Kaldi binaries
    """
    logger.info("Beginning transcription...")
    workflow = self.current_workflow
    if workflow.done:
        logger.info("Transcription already done, skipping!")
        return

    def _flag_workflow(column: str) -> None:
        # Set one boolean column on the current workflow row in its own session
        with self.session() as session:
            session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update(
                {column: True}
            )
            session.commit()

    try:
        is_transcription = workflow.workflow_type is WorkflowType.transcription
        if is_transcription:
            self.uses_speaker_adaptation = False
        self.decode()
        if is_transcription:
            rescoring_done = all(
                os.path.exists(lattice_path)
                for rescore_args in self.carpa_lm_rescore_arguments()
                for lattice_path in rescore_args.rescored_lat_paths.values()
            )
            if rescoring_done:
                logger.info("Rescoring already done.")
            else:
                logger.info("Performing speaker adjusted transcription...")
                self.transcribe_fmllr()
                self.lm_rescore()
                self.carpa_lm_rescore()
        self.collect_alignments()
        if self.fine_tune:
            self.fine_tune_alignments()
        if self.evaluation_mode:
            os.makedirs(self.working_log_directory, exist_ok=True)
            self.evaluate_transcriptions()
        _flag_workflow("done")
    except Exception as e:
        _flag_workflow("dirty")
        if isinstance(e, KaldiProcessingError):
            log_kaldi_errors(e.error_logs)
            e.update_log_file()
        raise
[docs]
def evaluate_transcriptions(self) -> Tuple[float, float]:
    """
    Evaluates the transcripts if there are reference transcripts

    Returns
    -------
    float, float
        Sentence error rate and word error rate

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
        If there were any errors in running Kaldi binaries
    """
    logger.info("Evaluating transcripts...")
    ser, wer, cer = self.compute_wer()
    logger.info(f"SER: {100 * ser:.2f}%, WER: {100 * wer:.2f}%, CER: {100 * cer:.2f}%")
    # The signature and docstring promise the two rates, but nothing was returned before
    return ser, wer
[docs]
def save_transcription_evaluation(self, output_directory: Path) -> None:
    """
    Save transcription evaluation to an output directory

    Writes ``transcription_evaluation.csv`` with one row per utterance that has
    non-empty normalized reference text.

    Parameters
    ----------
    output_directory: :class:`~pathlib.Path`
        Directory to save evaluation
    """
    output_path = output_directory.joinpath("transcription_evaluation.csv")
    with mfa_open(output_path, "w") as f, self.session() as session:
        writer = csv.writer(f)
        writer.writerow(
            [
                "file",
                "speaker",
                "begin",
                "end",
                "duration",
                "word_count",
                "oov_count",
                "gold_transcript",
                "hypothesis",
                "WER",
                "CER",
            ]
        )
        utterances = (
            session.query(
                Speaker.name,
                File.name,
                Utterance.begin,
                Utterance.end,
                Utterance.duration,
                Utterance.normalized_text,
                Utterance.transcription_text,
                Utterance.oovs,
                Utterance.word_error_rate,
                Utterance.character_error_rate,
            )
            .join(Utterance.speaker)
            .join(Utterance.file)
            .filter(Utterance.normalized_text != None)  # noqa
            .filter(Utterance.normalized_text != "")
        )
        for (
            speaker,
            file,
            begin,
            end,
            duration,
            text,
            transcription_text,
            oovs,
            word_error_rate,
            character_error_rate,
        ) in utterances:
            # Count whitespace-separated tokens. Counting spaces and adding one
            # reported a single OOV for an empty string and failed on None.
            word_count = len(text.split())
            oov_count = len(oovs.split()) if oovs else 0
            writer.writerow(
                [
                    file,
                    speaker,
                    begin,
                    end,
                    duration,
                    word_count,
                    oov_count,
                    text,
                    transcription_text,
                    word_error_rate,
                    character_error_rate,
                ]
            )
[docs]
def compute_wer(self) -> typing.Tuple[float, float, float]:
    """
    Evaluates the transcripts if there are reference transcripts

    Per-utterance word and character error rates are written back to the
    database for every utterance with non-empty normalized reference text.

    Returns
    -------
    float, float, float
        Sentence error rate, word error rate and character error rate over the
        corpus; all ``0.0`` when no utterance has reference text

    Raises
    ------
    Exception
        If the object has no ``db_engine`` attribute
    :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
        If there were any errors in running Kaldi binaries
    """
    if not hasattr(self, "db_engine"):
        raise Exception("Must be used as part of a class with a database engine")
    logger.info("Evaluating transcripts...")
    # Sentence-level measures
    incorrect = 0
    total_count = 0
    # Word-level measures
    total_word_edits = 0
    total_word_length = 0
    # Character-level measures
    total_character_edits = 0
    total_character_length = 0
    indices = []
    to_comp = []
    update_mappings = []
    with self.session() as session:
        utterances = session.query(Utterance)
        utterances = utterances.filter(Utterance.normalized_text != None)  # noqa
        utterances = utterances.filter(Utterance.normalized_text != "")
        for utt in utterances:
            g = utt.normalized_text.split()
            total_count += 1
            total_word_length += len(g)
            character_length = len("".join(g))
            total_character_length += character_length
            if not utt.transcription_text:
                # No hypothesis at all: every reference word and character is an error
                incorrect += 1
                total_word_edits += len(g)
                total_character_edits += character_length
                update_mappings.append(
                    {"id": utt.id, "word_error_rate": 1.0, "character_error_rate": 1.0}
                )
                continue
            h = utt.transcription_text.split()
            if g != h:
                # Defer edit-distance scoring of mismatches to the process pool
                indices.append(utt.id)
                to_comp.append((g, h))
                incorrect += 1
            else:
                update_mappings.append(
                    {"id": utt.id, "word_error_rate": 0.0, "character_error_rate": 0.0}
                )
        if total_count == 0:
            # Without reference text the rates are undefined; avoid dividing by zero
            logger.warning("No utterances with reference text were found to evaluate")
            return 0.0, 0.0, 0.0
        if to_comp:
            # Only start worker processes when there is something to score
            with mp.Pool(GLOBAL_CONFIG.num_jobs) as pool:
                scores = pool.starmap(score_wer, to_comp)
            for utt_id, (word_edits, word_length, character_edits, character_length) in zip(
                indices, scores
            ):
                update_mappings.append(
                    {
                        "id": utt_id,
                        "word_error_rate": word_edits / word_length,
                        "character_error_rate": character_edits / character_length,
                    }
                )
                total_word_edits += word_edits
                total_character_edits += character_edits
        bulk_update(session, Utterance, update_mappings)
        session.commit()
    ser = incorrect / total_count
    wer = total_word_edits / total_word_length if total_word_length else 0.0
    cer = total_character_edits / total_character_length if total_character_length else 0.0
    return ser, wer, cer
@property
def transcribe_fmllr_options(self) -> MetaDict:
    """Options needed for calculating fMLLR transformations"""
    option_names = ("acoustic_scale", "silence_weight", "lattice_beam")
    # Each option is read from the attribute of the same name
    return {name: getattr(self, name) for name in option_names}
@property
def lm_rescore_options(self) -> MetaDict:
    """Options needed for rescoring the language model"""
    options = {}
    options["acoustic_scale"] = self.acoustic_scale
    return options
[docs]
def decode(self) -> None:
    """
    Generate lattices

    The decoding function is chosen from the current workflow type, and the mean
    log-likelihood over all results is stored as the workflow's score.

    See Also
    -------
    :class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeFunction`
        Multiprocessing function
    :meth:`.TranscriberMixin.decode_arguments`
        Arguments for function
    """
    logger.info("Generating lattices...")
    with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
        workflow = self.current_workflow
        arguments = self.decode_arguments(workflow.workflow_type)
        log_likelihood_sum = 0
        log_likelihood_count = 0
        if workflow.workflow_type is WorkflowType.per_speaker_transcription:
            decode_function = PerSpeakerDecodeFunction
        elif workflow.workflow_type is WorkflowType.phone_transcription:
            decode_function = DecodePhoneFunction
        else:
            decode_function = DecodeFunction
        for _, log_likelihood, _ in run_kaldi_function(
            decode_function, arguments, pbar.update
        ):
            log_likelihood_sum += log_likelihood
            log_likelihood_count += 1
        if log_likelihood_count:
            average_log_likelihood = log_likelihood_sum / log_likelihood_count
            # Keep the in-memory object in step with what is stored
            workflow.score = average_log_likelihood
            # ``workflow`` was not loaded through the session opened here, so committing
            # that session alone would not persist the attribute assignment. Update the
            # row by id, the same way transcribe_utterances marks a workflow done.
            # NOTE(review): assumes ``score`` is a mapped column of CorpusWorkflow - confirm
            with self.session() as session:
                session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update(
                    {"score": average_log_likelihood}
                )
                session.commit()
[docs]
def calc_initial_fmllr(self) -> None:
    """
    Calculate initial fMLLR transforms

    See Also
    -------
    :class:`~montreal_forced_aligner.transcription.multiprocessing.InitialFmllrFunction`
        Multiprocessing function
    :meth:`.TranscriberMixin.initial_fmllr_arguments`
        Arguments for function

    Raises
    ------
    Exception
        The first error reported by a worker process, raised after all
        workers have been joined
    """
    logger.info("Calculating initial fMLLR transforms...")
    # The former ``sum_errors`` counter was never incremented, so the warning that
    # depended on it could not fire; both have been removed.
    with tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar:
        if GLOBAL_CONFIG.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(self.initial_fmllr_arguments()):
                function = InitialFmllrFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Queue is drained: stop polling once every worker reports finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                pbar.update(1)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in self.initial_fmllr_arguments():
                function = InitialFmllrFunction(args)
                for _ in function.run():
                    pbar.update(1)
[docs]
def lat_gen_fmllr(self) -> None:
    """
    Generate lattice with fMLLR transforms

    Each result is also appended to ``lat_gen_fmllr_log_like.csv`` in the working
    log directory as ``utterance,log_likelihood,num_frames``.

    See Also
    -------
    :class:`~montreal_forced_aligner.transcription.multiprocessing.LatGenFmllrFunction`
        Multiprocessing function
    :meth:`.TranscriberMixin.lat_gen_fmllr_arguments`
        Arguments for function
    """
    logger.info("Regenerating lattices with fMLLR transforms...")
    workflow = self.current_workflow
    job_arguments = self.lat_gen_fmllr_arguments(workflow.workflow_type)
    with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as progress, mfa_open(
        self.working_log_directory.joinpath("lat_gen_fmllr_log_like.csv"),
        "w",
        encoding="utf8",
    ) as likelihood_log:
        likelihood_log.write("utterance,log_likelihood,num_frames\n")
        if not GLOBAL_CONFIG.use_mp:
            # Run every job in this process, one after another
            for job_args in job_arguments:
                for utterance, log_likelihood, num_frames in LatGenFmllrFunction(job_args).run():
                    likelihood_log.write(f"{utterance},{log_likelihood},{num_frames}\n")
                    progress.update(1)
        else:
            errors_by_job = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            workers = []
            for index, job_args in enumerate(job_arguments):
                worker = KaldiProcessWorker(
                    index, return_queue, LatGenFmllrFunction(job_args), stopped
                )
                workers.append(worker)
                worker.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Nothing queued: leave once every worker reports finished
                    if all(worker.finished.stop_check() for worker in workers):
                        break
                    continue
                if isinstance(result, Exception):
                    errors_by_job[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                progress.update(1)
                utterance, log_likelihood, num_frames = result
                likelihood_log.write(f"{utterance},{log_likelihood},{num_frames}\n")
            for worker in workers:
                worker.join()
            if errors_by_job:
                for error in errors_by_job.values():
                    raise error
[docs]
def calc_final_fmllr(self) -> None:
    """
    Calculate final fMLLR transforms

    See Also
    -------
    :class:`~montreal_forced_aligner.transcription.multiprocessing.FinalFmllrFunction`
        Multiprocessing function
    :meth:`.TranscriberMixin.final_fmllr_arguments`
        Arguments for function

    Raises
    ------
    Exception
        The first error reported by a worker process, raised after all
        workers have been joined
    """
    logger.info("Calculating final fMLLR transforms...")
    # The former ``sum_errors`` counter was never incremented, so the warning that
    # depended on it could not fire; both have been removed.
    with tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar:
        if GLOBAL_CONFIG.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(self.final_fmllr_arguments()):
                function = FinalFmllrFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Queue is drained: stop polling once every worker reports finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                pbar.update(1)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in self.final_fmllr_arguments():
                function = FinalFmllrFunction(args)
                for _ in function.run():
                    pbar.update(1)
[docs]
def fmllr_rescore(self) -> None:
    """
    Rescore lattices with final fMLLR transforms

    See Also
    -------
    :class:`~montreal_forced_aligner.transcription.multiprocessing.FmllrRescoreFunction`
        Multiprocessing function
    :meth:`.TranscriberMixin.fmllr_rescore_arguments`
        Arguments for function

    Raises
    ------
    Exception
        The first error reported by a worker process, raised after all
        workers have been joined
    """
    logger.info("Rescoring fMLLR lattices with final transform...")
    sum_errors = 0
    with tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
        if GLOBAL_CONFIG.use_mp:
            error_dict = {}
            return_queue = mp.Queue()
            stopped = Stopped()
            procs = []
            for i, args in enumerate(self.fmllr_rescore_arguments()):
                function = FmllrRescoreFunction(args)
                p = KaldiProcessWorker(i, return_queue, function, stopped)
                procs.append(p)
                p.start()
            while True:
                try:
                    result = return_queue.get(timeout=1)
                except Empty:
                    # Queue is drained: stop polling once every worker reports finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                if isinstance(result, Exception):
                    error_dict[getattr(result, "job_name", 0)] = result
                    continue
                if stopped.stop_check():
                    continue
                done, errors = result
                sum_errors += errors
                pbar.update(done + errors)
            for p in procs:
                p.join()
            if error_dict:
                for v in error_dict.values():
                    raise v
        else:
            for args in self.fmllr_rescore_arguments():
                function = FmllrRescoreFunction(args)
                for done, errors in function.run():
                    sum_errors += errors
                    pbar.update(done + errors)
        if sum_errors:
            # Report the running total, not just the count from the last result received
            logger.warning(f"{sum_errors} utterances had errors on calculating fMLLR.")
[docs]
def transcribe_fmllr(self) -> None:
    """
    Run fMLLR estimation over initial decoding lattices and rescore

    The first-pass lattice files are deleted once the final transforms have been
    calculated and before the final rescoring step runs.

    See Also
    --------
    :func:`~montreal_forced_aligner.transcription.multiprocessing.InitialFmllrFunction`
        Multiprocessing helper function for each job
    :func:`~montreal_forced_aligner.transcription.multiprocessing.LatGenFmllrFunction`
        Multiprocessing helper function for each job
    :func:`~montreal_forced_aligner.transcription.multiprocessing.FinalFmllrFunction`
        Multiprocessing helper function for each job
    :func:`~montreal_forced_aligner.transcription.multiprocessing.FmllrRescoreFunction`
        Multiprocessing helper function for each job
    :func:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreFunction`
        Multiprocessing helper function for each job
    :func:`~montreal_forced_aligner.transcription.multiprocessing.CarpaLmRescoreFunction`
        Multiprocessing helper function for each job
    """
    current = self.current_workflow

    self.calc_initial_fmllr()
    # Speaker adaptation is switched on before the lattices are regenerated
    self.uses_speaker_adaptation = True
    self.lat_gen_fmllr()
    self.calc_final_fmllr()

    # Delete the lattice files listed in the decode arguments
    job_arguments = self.decode_arguments(current.workflow_type)
    for job_args in job_arguments:
        for lattice_path in job_args.lat_paths.values():
            os.remove(lattice_path)

    self.fmllr_rescore()
[docs]
def decode_arguments(
    self, workflow: WorkflowType = WorkflowType.transcription
) -> List[typing.Union[DecodeArguments, PerSpeakerDecodeArguments]]:
    """
    Generate Job arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeFunction`

    Parameters
    ----------
    workflow: :class:`~montreal_forced_aligner.data.WorkflowType`
        Workflow type that selects which argument class is built for each job

    Returns
    -------
    list[:class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeArguments`]
        Arguments for processing
    """
    job_arguments = []
    for job in self.jobs:
        # One feature processing string per dictionary used by the job
        feature_strings = {
            dict_id: job.construct_feature_proc_string(
                self.working_directory,
                dict_id,
                self.feature_options["uses_splices"],
                self.feature_options["splice_left_context"],
                self.feature_options["splice_right_context"],
                self.feature_options["uses_speaker_adaptation"],
            )
            for dict_id in job.dictionary_ids
        }
        db_string = getattr(self, "db_string", "")
        if workflow is WorkflowType.per_speaker_transcription:
            built = PerSpeakerDecodeArguments(
                job.id,
                db_string,
                self.working_log_directory.joinpath(f"per_speaker_decode.{job.id}.log"),
                self.model_directory,
                feature_strings,
                job.construct_path_dictionary(self.working_directory, "lat", "ark"),
                self.model_path,
                self.disambiguation_symbols_int_path,
                self.decode_options,
                self.tree_path,
                self.order,
                self.method,
            )
        elif workflow is WorkflowType.phone_transcription:
            built = DecodePhoneArguments(
                job.id,
                db_string,
                self.working_log_directory.joinpath(f"decode.{job.id}.log"),
                job.dictionary_ids,
                feature_strings,
                self.decode_options,
                self.alignment_model_path,
                job.construct_path_dictionary(self.working_directory, "lat", "ark"),
                self.phone_symbol_table_path,
                self.working_directory.joinpath("HCLG_phone.fst"),
            )
        else:
            # Regular decoding takes its limits from the first-pass settings
            first_pass_options = self.decode_options
            first_pass_options["max_active"] = first_pass_options["first_max_active"]
            first_pass_options["beam"] = first_pass_options["first_beam"]
            built = DecodeArguments(
                job.id,
                db_string,
                self.working_log_directory.joinpath(f"decode.{job.id}.log"),
                job.dictionary_ids,
                feature_strings,
                first_pass_options,
                self.alignment_model_path,
                job.construct_path_dictionary(self.working_directory, "lat", "ark"),
                job.construct_dictionary_dependent_paths(self.model_directory, "words", "txt"),
                job.construct_dictionary_dependent_paths(self.model_directory, "HCLG", "fst"),
            )
        job_arguments.append(built)
    return job_arguments
[docs]
def lm_rescore_arguments(self) -> List[LmRescoreArguments]:
    """
    Assemble one argument bundle per job for
    :class:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreArguments`]
        Arguments for processing, in the same order as ``self.jobs``
    """
    bundles = []
    for job in self.jobs:
        bundles.append(
            LmRescoreArguments(
                job.id,
                # Not every worker defines ``db_string``; fall back to an empty string
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"lm_rescore.{job.id}.log"),
                job.dictionary_ids,
                self.lm_rescore_options,
                # ``lat`` and ``lat.rescored`` archive paths in the working directory
                job.construct_path_dictionary(self.working_directory, "lat", "ark"),
                job.construct_path_dictionary(self.working_directory, "lat.rescored", "ark"),
                # Per-dictionary ``G_small`` and ``G_med`` FST paths in the model directory
                job.construct_dictionary_dependent_paths(self.model_directory, "G_small", "fst"),
                job.construct_dictionary_dependent_paths(self.model_directory, "G_med", "fst"),
            )
        )
    return bundles
[docs]
def carpa_lm_rescore_arguments(self) -> List[CarpaLmRescoreArguments]:
    """
    Assemble one argument bundle per job for
    :class:`~montreal_forced_aligner.transcription.multiprocessing.CarpaLmRescoreFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.transcription.multiprocessing.CarpaLmRescoreArguments`]
        Arguments for processing, in the same order as ``self.jobs``
    """
    bundles = []
    for job in self.jobs:
        bundles.append(
            CarpaLmRescoreArguments(
                job.id,
                # Not every worker defines ``db_string``; fall back to an empty string
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"carpa_lm_rescore.{job.id}.log"),
                job.dictionary_ids,
                # ``lat.rescored`` and ``lat.carpa.rescored`` archive paths in the working directory
                job.construct_path_dictionary(self.working_directory, "lat.rescored", "ark"),
                job.construct_path_dictionary(
                    self.working_directory, "lat.carpa.rescored", "ark"
                ),
                # Per-dictionary ``G_med`` FST and ``G`` carpa paths in the model directory
                job.construct_dictionary_dependent_paths(self.model_directory, "G_med", "fst"),
                job.construct_dictionary_dependent_paths(self.model_directory, "G", "carpa"),
            )
        )
    return bundles
@property
def fmllr_options(self) -> MetaDict:
    """Options for calculating fMLLR: the parent's options plus acoustic scale, silence phones and lattice beam"""
    options = super().fmllr_options
    # Keyword order mirrors the order in which the keys are inserted
    options.update(
        acoustic_scale=self.acoustic_scale,
        sil_phones=self.silence_csl,
        lattice_beam=self.lattice_beam,
    )
    return options
[docs]
def initial_fmllr_arguments(self) -> List[InitialFmllrArguments]:
    """
    Assemble one argument bundle per job for
    :class:`~montreal_forced_aligner.transcription.multiprocessing.InitialFmllrFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.transcription.multiprocessing.InitialFmllrArguments`]
        Arguments for processing, in the same order as ``self.jobs``
    """

    def feature_string(job, dictionary_id):
        # Feature processing string for one dictionary of one job
        options = self.feature_options
        return job.construct_feature_proc_string(
            self.working_directory,
            dictionary_id,
            options["uses_splices"],
            options["splice_left_context"],
            options["splice_right_context"],
            options["uses_speaker_adaptation"],
        )

    bundles = []
    for job in self.jobs:
        feat_strings = {d_id: feature_string(job, d_id) for d_id in job.dictionary_ids}
        bundles.append(
            InitialFmllrArguments(
                job.id,
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"initial_fmllr.{job.id}.log"),
                job.dictionary_ids,
                feat_strings,
                self.model_path,
                self.fmllr_options,
                # ``trans`` and ``lat`` archives live in the working directory,
                # ``spk2utt`` tables in the data directory
                job.construct_path_dictionary(self.working_directory, "trans", "ark"),
                job.construct_path_dictionary(self.working_directory, "lat", "ark"),
                job.construct_path_dictionary(self.data_directory, "spk2utt", "scp"),
            )
        )
    return bundles
[docs]
def lat_gen_fmllr_arguments(
    self, workflow: WorkflowType = WorkflowType.transcription
) -> List[LatGenFmllrArguments]:
    """
    Assemble one argument bundle per job for
    :class:`~montreal_forced_aligner.transcription.multiprocessing.LatGenFmllrFunction`

    Parameters
    ----------
    workflow: :class:`~montreal_forced_aligner.data.WorkflowType`
        Workflow being run. For ``phone_transcription`` a single phone-level graph,
        symbol table and feature string are passed; otherwise these are dictionaries
        keyed by dictionary id.

    Returns
    -------
    list[:class:`~montreal_forced_aligner.transcription.multiprocessing.LatGenFmllrArguments`]
        Arguments for processing, in the same order as ``self.jobs``
    """
    phone_level = workflow is WorkflowType.phone_transcription
    bundles = []
    for job in self.jobs:
        if phone_level:
            # One shared graph/symbol table; no dictionary id is given for the features
            hclg_paths = self.working_directory.joinpath("HCLG_phone.fst")
            word_paths = self.phone_symbol_table_path
            feat_strings = job.construct_feature_proc_string(
                self.working_directory,
                None,
                self.feature_options["uses_splices"],
                self.feature_options["splice_left_context"],
                self.feature_options["splice_right_context"],
                self.feature_options["uses_speaker_adaptation"],
            )
        else:
            feat_strings, word_paths, hclg_paths = {}, {}, {}
            for dictionary in job.dictionaries:
                dict_id = dictionary.id
                word_paths[dict_id] = dictionary.words_symbol_path
                hclg_paths[dict_id] = os.path.join(self.model_directory, f"HCLG.{dict_id}.fst")
                feat_strings[dict_id] = job.construct_feature_proc_string(
                    self.working_directory,
                    dict_id,
                    self.feature_options["uses_splices"],
                    self.feature_options["splice_left_context"],
                    self.feature_options["splice_right_context"],
                    self.feature_options["uses_speaker_adaptation"],
                )
        bundles.append(
            LatGenFmllrArguments(
                job.id,
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"lat_gen_fmllr.{job.id}.log"),
                job.dictionary_ids,
                feat_strings,
                self.model_path,
                self.decode_options,
                word_paths,
                hclg_paths,
                job.construct_path_dictionary(self.working_directory, "lat.tmp", "ark"),
            )
        )
    return bundles
[docs]
def final_fmllr_arguments(self) -> List[FinalFmllrArguments]:
    """
    Assemble one argument bundle per job for
    :class:`~montreal_forced_aligner.transcription.multiprocessing.FinalFmllrFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.transcription.multiprocessing.FinalFmllrArguments`]
        Arguments for processing, in the same order as ``self.jobs``
    """

    def feature_string(job, dictionary_id):
        # Feature processing string for one dictionary of one job
        options = self.feature_options
        return job.construct_feature_proc_string(
            self.working_directory,
            dictionary_id,
            options["uses_splices"],
            options["splice_left_context"],
            options["splice_right_context"],
            options["uses_speaker_adaptation"],
        )

    bundles = []
    for job in self.jobs:
        feat_strings = {d_id: feature_string(job, d_id) for d_id in job.dictionary_ids}
        bundles.append(
            FinalFmllrArguments(
                job.id,
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"final_fmllr.{job.id}.log"),
                job.dictionary_ids,
                feat_strings,
                self.model_path,
                self.fmllr_options,
                # ``trans`` and ``lat.tmp`` archives live in the working directory,
                # ``spk2utt`` tables in the data directory
                job.construct_path_dictionary(self.working_directory, "trans", "ark"),
                job.construct_path_dictionary(self.data_directory, "spk2utt", "scp"),
                job.construct_path_dictionary(self.working_directory, "lat.tmp", "ark"),
            )
        )
    return bundles
[docs]
def fmllr_rescore_arguments(self) -> List[FmllrRescoreArguments]:
    """
    Assemble one argument bundle per job for
    :class:`~montreal_forced_aligner.transcription.multiprocessing.FmllrRescoreFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.transcription.multiprocessing.FmllrRescoreArguments`]
        Arguments for processing, in the same order as ``self.jobs``
    """

    def feature_string(job, dictionary_id):
        # Feature processing string for one dictionary of one job
        options = self.feature_options
        return job.construct_feature_proc_string(
            self.working_directory,
            dictionary_id,
            options["uses_splices"],
            options["splice_left_context"],
            options["splice_right_context"],
            options["uses_speaker_adaptation"],
        )

    bundles = []
    for job in self.jobs:
        feat_strings = {d_id: feature_string(job, d_id) for d_id in job.dictionary_ids}
        bundles.append(
            FmllrRescoreArguments(
                job.id,
                getattr(self, "db_string", ""),
                self.working_log_directory.joinpath(f"fmllr_rescore.{job.id}.log"),
                job.dictionary_ids,
                feat_strings,
                self.model_path,
                self.fmllr_options,
                # ``lat.tmp`` and ``lat`` archive paths in the working directory
                job.construct_path_dictionary(self.working_directory, "lat.tmp", "ark"),
                job.construct_path_dictionary(self.working_directory, "lat", "ark"),
            )
        )
    return bundles
[docs]
class Transcriber(TranscriberMixin, TopLevelMfaWorker):
"""
Class for performing transcription.
Parameters
----------
acoustic_model_path : str
Path to acoustic model
language_model_path : str
Path to language model
evaluation_mode: bool
Flag for evaluating generated transcripts against the actual transcripts, defaults to False
See Also
--------
:class:`~montreal_forced_aligner.transcription.transcriber.TranscriberMixin`
For transcription parameters
:class:`~montreal_forced_aligner.corpus.acoustic_corpus.AcousticCorpusPronunciationMixin`
For corpus and dictionary parsing parameters
:class:`~montreal_forced_aligner.abc.FileExporterMixin`
For file exporting parameters
:class:`~montreal_forced_aligner.abc.TopLevelMfaWorker`
For top-level parameters
Attributes
----------
acoustic_model: AcousticModel
Acoustic model
language_model: LanguageModel
Language model
"""
def __init__(
    self,
    acoustic_model_path: Path,
    language_model_path: Path,
    output_type: str = "transcription",
    **kwargs,
):
    # The acoustic model is loaded before the parent initializers run because its
    # stored parameters are merged into (and override) the supplied keyword arguments.
    acoustic_model = AcousticModel(acoustic_model_path)
    self.acoustic_model = acoustic_model
    kwargs.update(acoustic_model.parameters)
    super().__init__(**kwargs)
    self.language_model = LanguageModel(language_model_path)
    # Compared against "transcription" in export_files to pick the export routine
    self.output_type = output_type
    self.ignore_empty_utterances = False
[docs]
def create_hclgs_arguments(self) -> Dict[int, CreateHclgArguments]:
    """
    Assemble per-dictionary arguments for
    :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction`

    Returns
    -------
    dict[int, :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgArguments`]
        HCLG creation arguments keyed by dictionary id
    """
    model_directory = self.model_directory
    language_model = self.language_model
    with self.session() as session:
        return {
            dictionary.id: CreateHclgArguments(
                dictionary.id,
                getattr(self, "db_string", ""),
                model_directory.joinpath("log", f"hclg.{dictionary.id}.log"),
                model_directory,
                model_directory.joinpath(f"words.{dictionary.id}.txt"),
                model_directory.joinpath(f"G.{dictionary.id}.carpa"),
                language_model.small_arpa_path,
                language_model.medium_arpa_path,
                language_model.carpa_path,
                self.model_path,
                dictionary.lexicon_disambig_fst_path,
                dictionary.disambiguation_symbols_int_path,
                self.hclg_options,
                self.word_mapping(dictionary.id),
            )
            for dictionary in session.query(Dictionary)
        }
[docs]
def create_hclgs(self) -> None:
    """
    Create HCLG.fst files for every dictionary being used by a
    :class:`~montreal_forced_aligner.transcription.transcriber.Transcriber`

    Runs one :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction`
    per dictionary, in worker processes when ``GLOBAL_CONFIG.use_mp`` is set and
    serially otherwise, then checks that every expected ``HCLG.{id}.fst`` exists.

    Raises
    ------
    Exception
        The first exception reported by a worker process, if any
    :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
        If any ``HCLG.{id}.fst`` file is missing from the model directory afterwards
    """
    dict_arguments = list(self.create_hclgs_arguments().values())
    logger.info("Generating HCLG.fst...")
    # Both branches tick the bar once per intermediate (non-tuple) result and once for
    # the final tuple, so they share the same total. The serial branch used to open a
    # fresh bar per dictionary with a total of len(dict_arguments).
    progress_total = len(dict_arguments) * 7
    if GLOBAL_CONFIG.use_mp:
        error_dict = {}
        return_queue = mp.Queue()
        stopped = Stopped()
        procs = []
        for i, args in enumerate(dict_arguments):
            function = CreateHclgFunction(args)
            p = KaldiProcessWorker(i, return_queue, function, stopped)
            procs.append(p)
            p.start()
        with tqdm(total=progress_total, disable=GLOBAL_CONFIG.quiet) as pbar:
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        # Keep draining the queue; the error is raised after joining
                        error_dict[getattr(result, "job_name", 0)] = result
                        continue
                    elif not isinstance(result, tuple):
                        # Intermediate progress signal
                        pbar.update(1)
                        continue
                    if stopped.stop_check():
                        continue
                except Empty:
                    # Nothing queued: finish only once every worker reports finished
                    if all(proc.finished.stop_check() for proc in procs):
                        break
                    continue
                result, hclg_path = result
                if result:
                    logger.debug(f"Done generating {hclg_path}!")
                else:
                    logger.warning(f"There was an error in generating {hclg_path}")
                pbar.update(1)
        for p in procs:
            p.join()
        if error_dict:
            for v in error_dict.values():
                raise v
    else:
        with tqdm(total=progress_total, disable=GLOBAL_CONFIG.quiet) as pbar:
            for args in dict_arguments:
                function = CreateHclgFunction(args)
                for result in function.run():
                    if not isinstance(result, tuple):
                        # Intermediate progress signal
                        pbar.update(1)
                        continue
                    result, hclg_path = result
                    if result:
                        logger.debug(f"Done generating {hclg_path}!")
                    else:
                        logger.warning(f"There was an error in generating {hclg_path}")
                    pbar.update(1)
    # Verify the outputs: report the log of every job whose graph is missing
    error_logs = [
        arg.log_path
        for arg in dict_arguments
        if not self.model_directory.joinpath(f"HCLG.{arg.job_name}.fst").exists()
    ]
    if error_logs:
        raise KaldiProcessingError(error_logs)
[docs]
def create_decoding_graph(self) -> None:
    """
    Create decoding graph for use in transcription

    Writes lexicon information, copies each dictionary's word symbol table into the
    model directory, generates pruned small/medium ARPA models with the ``ngramread``,
    ``ngramshrink`` and ``ngramprint`` command line tools when they do not exist yet,
    and finally builds the HCLG graphs. If a ``done`` marker file is present in the
    model directory, nothing is rebuilt.

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
        If there were any errors in running Kaldi binaries
    """
    done_path = os.path.join(self.model_directory, "done")
    if os.path.exists(done_path):
        logger.info("Graph construction already done, skipping!")
        # Actually skip, as the message says; previously execution fell through
        # and the whole graph was rebuilt regardless of the marker.
        return
    log_dir = os.path.join(self.model_directory, "log")
    os.makedirs(log_dir, exist_ok=True)
    self.write_lexicon_information(write_disambiguation=True)
    with self.session() as session:
        for d in session.query(Dictionary):
            words_path = os.path.join(self.model_directory, f"words.{d.id}.txt")
            shutil.copyfile(d.words_symbol_path, words_path)
    big_arpa_path = self.language_model.carpa_path
    small_arpa_path = self.language_model.small_arpa_path
    medium_arpa_path = self.language_model.medium_arpa_path
    if not os.path.exists(small_arpa_path) or not os.path.exists(medium_arpa_path):
        logger.warning(
            "Creating small and medium language models from scratch, this may take some time. "
            "Running `mfa train_lm` on the ARPA file will remove this warning."
        )
        logger.info("Parsing large ngram model...")
        mod_path = self.model_directory.joinpath("base_lm.mod")
        new_carpa_path = os.path.join(self.model_directory, "base_lm.arpa")
        # Work on a lower-cased copy of the large ARPA file
        with mfa_open(big_arpa_path, "r") as inf, mfa_open(new_carpa_path, "w") as outf:
            for line in inf:
                outf.write(line.lower())
        big_arpa_path = new_carpa_path
        subprocess.call(["ngramread", "--ARPA", big_arpa_path, mod_path])
        # (size name, model stem suffix, pruning threshold, threshold as logged, output ARPA)
        pruned_models = (
            ("small", "_small", 0.0000003, "3e-7", small_arpa_path),
            ("medium", "_med", 0.0000001, "1e-7", medium_arpa_path),
        )
        for size, stem_suffix, theta, theta_label, arpa_path in pruned_models:
            if os.path.exists(arpa_path):
                continue
            logger.info(
                f"Generating {size} model from the large ARPA with a pruning threshold of {theta_label}"
            )
            pruned_mod_path = mod_path.with_stem(mod_path.stem + stem_suffix)
            subprocess.call(
                [
                    "ngramshrink",
                    "--method=relative_entropy",
                    f"--theta={theta}",
                    mod_path,
                    pruned_mod_path,
                ]
            )
            subprocess.call(["ngramprint", "--ARPA", pruned_mod_path, arpa_path])
    try:
        self.create_hclgs()
    except Exception as e:
        # Leave a "dirty" marker in the model directory before re-raising
        dirty_path = os.path.join(self.model_directory, "dirty")
        with mfa_open(dirty_path, "w"):
            pass
        if isinstance(e, KaldiProcessingError):
            log_kaldi_errors(e.error_logs)
            e.update_log_file()
        raise
[docs]
@classmethod
def parse_parameters(
    cls,
    config_path: Optional[Path] = None,
    args: Optional[Dict[str, typing.Any]] = None,
    unknown_args: Optional[typing.Iterable[str]] = None,
) -> MetaDict:
    """
    Parse configuration parameters from a config file and command line arguments

    Parameters
    ----------
    config_path: :class:`~pathlib.Path`, optional
        Path to yaml configuration file
    args: dict[str, Any], optional
        Parsed arguments; may be omitted, in which case no command line overrides
        for language model weight or word insertion penalty are applied
    unknown_args: list[str], optional
        Optional list of arguments that were not parsed

    Returns
    -------
    dict[str, Any]
        Dictionary of specified configuration parameters
    """
    global_params = {}
    if config_path and os.path.exists(config_path):
        data = parse_old_features(load_configuration(config_path))
        for key, value in data.items():
            if key == "features":
                # Feature settings are flattened into the top-level parameters
                global_params.update(value)
                continue
            if value is None and key in cls.nullable_fields:
                value = []
            global_params[key] = value
    global_params.update(cls.parse_args(args, unknown_args))
    # ``args`` defaults to None, so it must be checked before calling ``.get`` on it
    if args is not None:
        language_model_weight = args.get("language_model_weight", None)
        if language_model_weight is not None:
            global_params["min_language_model_weight"] = language_model_weight
            global_params["max_language_model_weight"] = language_model_weight + 1
        word_insertion_penalty = args.get("word_insertion_penalty", None)
        if word_insertion_penalty is not None:
            global_params["word_insertion_penalties"] = [word_insertion_penalty]
    return global_params
def setup_acoustic_model(self) -> None:
    """Validate the acoustic model against this worker, export it to the model and working directories, and log its details"""
    model = self.acoustic_model
    model.validate(self)
    # The model is exported twice: once per directory
    model.export_model(self.model_directory)
    model.export_model(self.working_directory)
    model.log_details()
[docs]
def setup(self) -> None:
    """
    Set up transcription

    Loads the corpus, clears the working and model directories if a ``dirty`` marker
    file is found in them, exports the acoustic model and builds the decoding graph.
    Does nothing further if the worker is already initialized.
    """
    self.alignment_mode = False
    TopLevelMfaWorker.setup(self)
    if self.initialized:
        return
    self.create_new_current_workflow(WorkflowType.transcription)
    start_time = time.time()
    os.makedirs(self.working_log_directory, exist_ok=True)
    self.load_corpus()

    # A "dirty" marker in the working directory triggers removal of that directory
    working_marker = self.working_directory.joinpath("dirty")
    if os.path.exists(working_marker):
        shutil.rmtree(self.working_directory, ignore_errors=True)
    os.makedirs(self.working_log_directory, exist_ok=True)

    # if there was an error, let's redo from scratch
    model_marker = os.path.join(self.model_directory, "dirty")
    if os.path.exists(model_marker):
        shutil.rmtree(self.model_directory)
    os.makedirs(os.path.join(self.model_directory, "log"), exist_ok=True)

    self.setup_acoustic_model()
    self.create_decoding_graph()
    self.initialized = True
    logger.debug(f"Setup for transcription in {time.time() - start_time:.3f} seconds")
[docs]
def export_transcriptions(self) -> None:
    """
    Export transcriptions

    Each file is written to ``self.export_output_directory``. A file whose single
    utterance starts at 0 and ends at the sound file's duration is written as a
    plain-text ``lab`` file; any other file with utterances is written as a short
    TextGrid with one tier per speaker. Files without utterances are skipped. When
    ``self.evaluation_mode`` is set, the transcription evaluation is saved as well.
    """
    with self.session() as session:
        files = session.query(File).options(
            selectinload(File.utterances),
            selectinload(File.speakers),
            joinedload(File.sound_file, innerjoin=True).load_only(SoundFile.duration),
        )
        for file in files:
            utterance_count = len(file.utterances)
            duration = file.sound_file.duration
            if utterance_count == 0:
                logger.debug(f"Could not find any utterances for {file.name}")
                # Nothing to export. Without this, ``output_format`` below would be
                # unbound for the first file or left over from the previous one.
                continue
            if (
                utterance_count == 1
                and file.utterances[0].begin == 0
                and file.utterances[0].end == duration
            ):
                output_format = "lab"
            else:
                output_format = TextgridFormats.SHORT_TEXTGRID
            output_path = construct_output_path(
                file.name,
                file.relative_path,
                self.export_output_directory,
                output_format=output_format,
            )
            data = file.construct_transcription_tiers()
            if output_format == "lab":
                for intervals in data.values():
                    with mfa_open(output_path, "w") as f:
                        f.write(intervals["transcription"][0].label)
            else:
                max_time = round(duration, 5)
                tg = textgrid.Textgrid()
                tg.minTimestamp = 0
                tg.maxTimestamp = max_time
                for speaker in file.speakers:
                    speaker_name = speaker.name
                    intervals = data[speaker_name]["transcription"]
                    tier = textgrid.IntervalTier(
                        speaker_name,
                        [x.to_tg_interval() for x in intervals],
                        minT=0,
                        maxT=max_time,
                    )
                    tg.addTier(tier)
                tg.save(output_path, includeBlankSpaces=True, format=output_format)
    if self.evaluation_mode:
        self.save_transcription_evaluation(self.export_output_directory)
[docs]
def export_files(
    self,
    output_directory: Path,
    output_format: Optional[str] = None,
    include_original_text: bool = False,
) -> None:
    """
    Export transcriptions

    Parameters
    ----------
    output_directory: :class:`~pathlib.Path`
        Directory to save transcriptions; created if it does not exist
    output_format: str, optional
        Format to save alignments, one of 'long_textgrids' (the default), 'short_textgrids', or 'json', passed to praatio.
        When omitted, ``TextFileType.TEXTGRID.value`` is used.
    include_original_text: bool
        Forwarded to ``export_textgrids``; not used when ``output_type`` is "transcription"
    """
    export_format = TextFileType.TEXTGRID.value if output_format is None else output_format
    self.export_output_directory = output_directory
    os.makedirs(self.export_output_directory, exist_ok=True)
    if self.output_type != "transcription":
        self.export_textgrids(export_format, include_original_text)
        return
    self.export_transcriptions()