Source code for montreal_forced_aligner.language_modeling.multiprocessing

"""Multiprocessing functions for training language models"""
from __future__ import annotations

import os
import subprocess
import typing
from pathlib import Path

import sqlalchemy
from sqlalchemy.orm import Session, joinedload, subqueryload

from montreal_forced_aligner.config import GLOBAL_CONFIG
from montreal_forced_aligner.data import MfaArguments, WordType
from montreal_forced_aligner.db import Job, Phone, PhoneInterval, Speaker, Utterance, Word
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.transcription.multiprocessing import (
    compose_clg,
    compose_hclg,
    compose_lg,
)
from montreal_forced_aligner.utils import KaldiFunction, thirdparty_binary

if typing.TYPE_CHECKING:
    from dataclasses import dataclass

    from montreal_forced_aligner.abc import MetaDict

else:
    from dataclassy import dataclass


[docs] @dataclass class TrainSpeakerLmArguments(MfaArguments): """ Arguments for :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainSpeakerLmFunction` Parameters ---------- job_name: int Integer ID of the job db_string: str String for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run model_path: :class:`~pathlib.Path` Path to model order: int Ngram order of the language models method: str Ngram smoothing method target_num_ngrams: int Target number of ngrams hclg_options: dict[str, Any] HCLG creation options """ model_path: Path order: int method: str target_num_ngrams: int hclg_options: MetaDict
@dataclass class TrainLmArguments(MfaArguments): """ Arguments for :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainSpeakerLmFunction` Parameters ---------- job_name: int Integer ID of the job db_string: str String for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run symbols_path: :class:`~pathlib.Path` Words symbol table paths oov_word: str OOV word order: int Ngram order of the language models """ working_directory: Path symbols_path: Path order: int oov_word: str class TrainLmFunction(KaldiFunction): """ Multiprocessing function to training small language models for each speaker See Also -------- :openfst_src:`farcompilestrings` Relevant OpenFst binary :ngram_src:`ngramcount` Relevant OpenGrm-Ngram binary :ngram_src:`ngrammake` Relevant OpenGrm-Ngram binary :ngram_src:`ngramshrink` Relevant OpenGrm-Ngram binary Parameters ---------- args: :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainSpeakerLmArguments` Arguments for the function """ def __init__(self, args: TrainLmArguments): super().__init__(args) self.working_directory = args.working_directory self.symbols_path = args.symbols_path self.order = args.order self.oov_word = args.oov_word def _run(self) -> typing.Generator[bool]: """Run the function""" with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file: word_query = session.query(Word.word).filter(Word.word_type == WordType.speech) included_words = set(x[0] for x in word_query) utterance_query = session.query(Utterance.normalized_text, Utterance.text).filter( Utterance.job_id == self.job_name ) farcompile_proc = subprocess.Popen( [ thirdparty_binary("farcompilestrings"), "--token_type=symbol", "--generate_keys=16", "--keep_symbols", f"--symbols={self.symbols_path}", ], stderr=log_file, stdin=subprocess.PIPE, stdout=subprocess.PIPE, env=os.environ, ) ngramcount_proc = subprocess.Popen( [ thirdparty_binary("ngramcount"), "--round_to_int", f"--order={self.order}", "-", self.working_directory.joinpath(f"{self.job_name}.cnts"), ], stderr=log_file, stdin=farcompile_proc.stdout, env=os.environ, ) for (normalized_text, text) in utterance_query: if not normalized_text: normalized_text = text text = " ".join( x if x in included_words else self.oov_word for x in normalized_text.split() ) farcompile_proc.stdin.write(f"{text}\n".encode("utf8")) farcompile_proc.stdin.flush() yield 1 farcompile_proc.stdin.close() self.check_call(ngramcount_proc) class TrainPhoneLmFunction(KaldiFunction): """ Multiprocessing function to training small language models for each speaker See Also -------- :openfst_src:`farcompilestrings` Relevant OpenFst binary :ngram_src:`ngramcount` Relevant OpenGrm-Ngram binary :ngram_src:`ngrammake` Relevant OpenGrm-Ngram binary :ngram_src:`ngramshrink` Relevant OpenGrm-Ngram binary Parameters ---------- args: :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainSpeakerLmArguments` Arguments for the function """ def __init__(self, args: TrainLmArguments): super().__init__(args) self.working_directory = args.working_directory self.symbols_path = args.symbols_path self.order = args.order def _run(self) -> typing.Generator[bool]: """Run the function""" with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file: if GLOBAL_CONFIG.current_profile.use_postgres: string_agg_function = sqlalchemy.func.string_agg else: string_agg_function = sqlalchemy.func.group_concat pronunciation_query = ( sqlalchemy.select(Utterance.id, string_agg_function(Phone.kaldi_label, " ")) .select_from(Utterance) .join(Utterance.phone_intervals) .join(PhoneInterval.phone) .where(Utterance.job_id == self.job_name) .group_by(Utterance.id) ) farcompile_proc = subprocess.Popen( [ thirdparty_binary("farcompilestrings"), "--token_type=symbol", "--generate_keys=16", f"--symbols={self.symbols_path}", ], stderr=log_file, stdin=subprocess.PIPE, stdout=subprocess.PIPE, env=os.environ, ) ngramcount_proc = subprocess.Popen( [ thirdparty_binary("ngramcount"), "--require_symbols=false", "--round_to_int", f"--order={self.order}", "-", self.working_directory.joinpath(f"{self.job_name}.cnts"), ], stderr=log_file, stdin=farcompile_proc.stdout, env=os.environ, ) for utt_id, phones in session.execute(pronunciation_query): farcompile_proc.stdin.write(f"{phones}\n".encode("utf8")) farcompile_proc.stdin.flush() yield utt_id, phones farcompile_proc.stdin.close() self.check_call(ngramcount_proc)
[docs] class TrainSpeakerLmFunction(KaldiFunction): """ Multiprocessing function to training small language models for each speaker See Also -------- :openfst_src:`farcompilestrings` Relevant OpenFst binary :ngram_src:`ngramcount` Relevant OpenGrm-Ngram binary :ngram_src:`ngrammake` Relevant OpenGrm-Ngram binary :ngram_src:`ngramshrink` Relevant OpenGrm-Ngram binary Parameters ---------- args: :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainSpeakerLmArguments` Arguments for the function """ def __init__(self, args: TrainSpeakerLmArguments): super().__init__(args) self.model_path = args.model_path self.order = args.order self.method = args.method self.target_num_ngrams = args.target_num_ngrams self.hclg_options = args.hclg_options def _run(self) -> typing.Generator[bool]: """Run the function""" with Session(self.db_engine()) as session, mfa_open(self.log_path, "w") as log_file: job: Job = ( session.query(Job) .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries)) .filter(Job.id == self.job_name) .first() ) for d in job.dictionaries: dict_id = d.id word_symbols_path = d.words_symbol_path speakers = ( session.query(Speaker.id) .join(Utterance.speaker) .filter(Utterance.job_id == job.id) .filter(Speaker.dictionary_id == dict_id) .distinct() ) for (speaker_id,) in speakers: hclg_path = d.temp_directory.joinpath(f"{speaker_id}.fst") if os.path.exists(hclg_path): continue utterances = ( session.query(Utterance.normalized_text) .filter(Utterance.speaker_id == speaker_id) .order_by(Utterance.kaldi_id) ) mod_path = d.temp_directory.joinpath(f"g.{speaker_id}.fst") farcompile_proc = subprocess.Popen( [ thirdparty_binary("farcompilestrings"), "--fst_type=compact", f"--unknown_symbol={d.oov_word}", f"--symbols={word_symbols_path}", "--keep_symbols", "--generate_keys=16", ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=log_file, env=os.environ, ) count_proc = subprocess.Popen( [thirdparty_binary("ngramcount"), f"--order={self.order}"], stdin=farcompile_proc.stdout, stdout=subprocess.PIPE, stderr=log_file, ) make_proc = subprocess.Popen( [thirdparty_binary("ngrammake"), "--method=kneser_ney"], stdin=count_proc.stdout, stdout=subprocess.PIPE, stderr=log_file, env=os.environ, ) shrink_proc = subprocess.Popen( [ thirdparty_binary("ngramshrink"), "--method=relative_entropy", f"--target_number_of_ngrams={self.target_num_ngrams}", "--shrink_opt=2", "--theta=0.001", "-", mod_path, ], stdin=make_proc.stdout, stderr=log_file, env=os.environ, ) for (text,) in utterances: farcompile_proc.stdin.write(f"{text}\n".encode("utf8")) farcompile_proc.stdin.flush() farcompile_proc.stdin.close() shrink_proc.wait() context_width = self.hclg_options["context_width"] central_pos = self.hclg_options["central_pos"] lg_path = d.temp_directory.joinpath(f"LG.{speaker_id}.fst") hclga_path = d.temp_directory.joinpath(f"HCLGa.{speaker_id}.fst") ilabels_temp = d.temp_directory.joinpath( f"ilabels_{context_width}_{central_pos}.{speaker_id}" ) out_disambig = d.temp_directory.joinpath( f"disambig_ilabels_{context_width}_{central_pos}.int" ) clg_path = d.temp_directory.joinpath( f"CLG_{context_width}_{central_pos}.{speaker_id}.fst" ) log_file.write("Generating LG.fst...") compose_lg(d.lexicon_disambig_fst_path, mod_path, lg_path, log_file) log_file.write("Generating CLG.fst...") compose_clg( d.disambiguation_symbols_int_path, out_disambig, context_width, central_pos, ilabels_temp, lg_path, clg_path, log_file, ) log_file.write("Generating HCLGa.fst...") compose_hclg( self.model_path, ilabels_temp, self.hclg_options["transition_scale"], clg_path, hclga_path, log_file, ) log_file.write("Generating HCLG.fst...") self_loop_proc = subprocess.Popen( [ thirdparty_binary("add-self-loops"), f"--self-loop-scale={self.hclg_options['self_loop_scale']}", "--reorder=true", self.model_path, hclga_path, ], stderr=log_file, stdout=subprocess.PIPE, env=os.environ, ) convert_proc = subprocess.Popen( [ thirdparty_binary("fstconvert"), "--v=100", "--fst_type=const", "-", hclg_path, ], stdin=self_loop_proc.stdout, stderr=log_file, env=os.environ, ) convert_proc.communicate() self.check_call(convert_proc) os.remove(mod_path) os.remove(lg_path) os.remove(clg_path) os.remove(hclga_path) os.remove(ilabels_temp) os.remove(out_disambig) yield os.path.exists(hclg_path)