# Source code for montreal_forced_aligner.language_modeling.multiprocessing

"""Multiprocessing functions for training language models"""
from __future__ import annotations

import os
import subprocess
import typing
from pathlib import Path

import sqlalchemy
from _kalpy.fstext import VectorFst
from kalpy.decoder.decode_graph import DecodeGraphCompiler
from kalpy.fstext.lexicon import LexiconCompiler
from sqlalchemy.orm import joinedload, subqueryload

from montreal_forced_aligner import config
from montreal_forced_aligner.abc import KaldiFunction
from montreal_forced_aligner.data import MfaArguments, WordType
from montreal_forced_aligner.db import Job, Phone, PhoneInterval, Speaker, Utterance, Word
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import thirdparty_binary

if typing.TYPE_CHECKING:
    from dataclasses import dataclass

    from montreal_forced_aligner.abc import MetaDict
else:
    from dataclassy import dataclass


@dataclass
class TrainSpeakerLmArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainSpeakerLmFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    model_path: :class:`~pathlib.Path`
        Path to model
    tree_path: :class:`~pathlib.Path`
        Path to tree file
    lexicon_compilers: dict[int, :class:`~kalpy.fstext.lexicon.LexiconCompiler`]
        Lexicon compilers keyed by dictionary ID
    order: int
        Ngram order of the language models
    method: str
        Ngram smoothing method
    target_num_ngrams: int
        Target number of ngrams
    hclg_options: dict[str, Any]
        HCLG creation options
    """

    model_path: Path
    tree_path: Path
    lexicon_compilers: typing.Dict[int, LexiconCompiler]
    order: int
    method: str
    target_num_ngrams: int
    hclg_options: MetaDict
@dataclass
class TrainLmArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainLmFunction`

    Parameters
    ----------
    job_name: int
        Integer ID of the job
    db_string: str
        String for database connections
    log_path: :class:`~pathlib.Path`
        Path to save logging information during the run
    working_directory: :class:`~pathlib.Path`
        Directory where the job's ngram count file is written
    symbols_path: :class:`~pathlib.Path`
        Words symbol table path
    order: int
        Ngram order of the language models
    oov_word: str
        OOV word
    """

    working_directory: Path
    symbols_path: Path
    order: int
    oov_word: str


class TrainLmFunction(KaldiFunction):
    """
    Multiprocessing function for accumulating ngram counts over a job's utterance text

    See Also
    --------
    :openfst_src:`farcompilestrings`
        Relevant OpenFst binary
    :ngram_src:`ngramcount`
        Relevant OpenGrm-Ngram binary
    :ngram_src:`ngrammake`
        Relevant OpenGrm-Ngram binary
    :ngram_src:`ngramshrink`
        Relevant OpenGrm-Ngram binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainLmArguments`
        Arguments for the function
    """

    def __init__(self, args: TrainLmArguments):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.symbols_path = args.symbols_path
        self.order = args.order
        self.oov_word = args.oov_word

    def _run(self) -> typing.Generator[bool]:
        """Run the function"""
        with self.session() as session, mfa_open(self.log_path, "w") as log_file:
            # Words of speech types stay in the training text; anything else is
            # replaced by the OOV word below.
            word_query = session.query(Word.word).filter(
                Word.word_type.in_(WordType.speech_types())
            )
            included_words = set(x[0] for x in word_query)
            utterance_query = session.query(Utterance.normalized_text, Utterance.text).filter(
                Utterance.job_id == self.job_name
            )
            # Compile each utterance line fed on stdin into a FAR of FSTs over the
            # word symbol table.
            farcompile_proc = subprocess.Popen(
                [
                    thirdparty_binary("farcompilestrings"),
                    "--fst_type=compact",
                    "--token_type=symbol",
                    "--generate_keys=16",
                    "--keep_symbols",
                    f"--symbols={self.symbols_path}",
                ],
                stderr=log_file,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            # Pipe the FAR straight into ngramcount, writing this job's counts to
            # <job_name>.cnts in the working directory.
            ngramcount_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramcount"),
                    "--round_to_int",
                    f"--order={self.order}",
                    "-",
                    self.working_directory.joinpath(f"{self.job_name}.cnts"),
                ],
                stderr=log_file,
                stdin=farcompile_proc.stdout,
                env=os.environ,
            )
            for normalized_text, text in utterance_query:
                # Fall back to the raw text when no normalized form is stored
                if not normalized_text:
                    normalized_text = text
                text = " ".join(
                    x if x in included_words else self.oov_word for x in normalized_text.split()
                )
                farcompile_proc.stdin.write(f"{text}\n".encode("utf8"))
                farcompile_proc.stdin.flush()
                self.callback(1)
            farcompile_proc.stdin.close()
            self.check_call(ngramcount_proc)


class TrainPhoneLmFunction(KaldiFunction):
    """
    Multiprocessing function for accumulating ngram counts over a job's aligned phone sequences

    See Also
    --------
    :openfst_src:`farcompilestrings`
        Relevant OpenFst binary
    :ngram_src:`ngramcount`
        Relevant OpenGrm-Ngram binary
    :ngram_src:`ngrammake`
        Relevant OpenGrm-Ngram binary
    :ngram_src:`ngramshrink`
        Relevant OpenGrm-Ngram binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainLmArguments`
        Arguments for the function
    """

    def __init__(self, args: TrainLmArguments):
        super().__init__(args)
        self.working_directory = args.working_directory
        self.symbols_path = args.symbols_path
        self.order = args.order

    def _run(self) -> typing.Generator[bool]:
        """Run the function"""
        with self.session() as session, mfa_open(self.log_path, "w") as log_file:
            # group_concat is the non-Postgres (SQLite) equivalent of string_agg
            if config.USE_POSTGRES:
                string_agg_function = sqlalchemy.func.string_agg
            else:
                string_agg_function = sqlalchemy.func.group_concat
            # One row per utterance: its phone interval labels joined into a
            # single space-separated string.
            pronunciation_query = (
                sqlalchemy.select(Utterance.id, string_agg_function(Phone.kaldi_label, " "))
                .select_from(Utterance)
                .join(Utterance.phone_intervals)
                .join(PhoneInterval.phone)
                .where(Utterance.job_id == self.job_name)
                .group_by(Utterance.id)
            )
            farcompile_proc = subprocess.Popen(
                [
                    thirdparty_binary("farcompilestrings"),
                    "--fst_type=compact",
                    "--token_type=symbol",
                    "--generate_keys=16",
                    f"--symbols={self.symbols_path}",
                ],
                stderr=log_file,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            # Write this job's phone ngram counts to <job_name>.cnts
            ngramcount_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramcount"),
                    "--require_symbols=false",
                    "--round_to_int",
                    f"--order={self.order}",
                    "-",
                    self.working_directory.joinpath(f"{self.job_name}.cnts"),
                ],
                stderr=log_file,
                stdin=farcompile_proc.stdout,
                env=os.environ,
            )
            for utt_id, phones in session.execute(pronunciation_query):
                farcompile_proc.stdin.write(f"{phones}\n".encode("utf8"))
                farcompile_proc.stdin.flush()
                self.callback((utt_id, phones))
            farcompile_proc.stdin.close()
            self.check_call(ngramcount_proc)
class TrainSpeakerLmFunction(KaldiFunction):
    """
    Multiprocessing function for training small language models for each speaker
    and compiling them into per-speaker HCLG decoding graphs

    See Also
    --------
    :openfst_src:`farcompilestrings`
        Relevant OpenFst binary
    :ngram_src:`ngramcount`
        Relevant OpenGrm-Ngram binary
    :ngram_src:`ngrammake`
        Relevant OpenGrm-Ngram binary
    :ngram_src:`ngramshrink`
        Relevant OpenGrm-Ngram binary

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.language_modeling.multiprocessing.TrainSpeakerLmArguments`
        Arguments for the function
    """

    def __init__(self, args: TrainSpeakerLmArguments):
        super().__init__(args)
        self.model_path = args.model_path
        self.tree_path = args.tree_path
        self.lexicon_compilers = args.lexicon_compilers
        self.order = args.order
        self.method = args.method
        self.target_num_ngrams = args.target_num_ngrams
        self.hclg_options = args.hclg_options

    def _run(self) -> typing.Generator[bool]:
        """Run the function"""
        with self.session() as session, mfa_open(self.log_path, "w") as log_file:
            job: Job = (
                session.query(Job)
                .options(joinedload(Job.corpus, innerjoin=True), subqueryload(Job.dictionaries))
                .filter(Job.id == self.job_name)
                .first()
            )
            for d in job.dictionaries:
                dict_id = d.id
                word_symbols_path = d.words_symbol_path
                # Speakers of this job that use this dictionary
                speakers = (
                    session.query(Speaker.id)
                    .join(Utterance.speaker)
                    .filter(Utterance.job_id == job.id)
                    .filter(Speaker.dictionary_id == dict_id)
                    .distinct()
                )
                for (speaker_id,) in speakers:
                    hclg_path = d.temp_directory.joinpath(f"{speaker_id}.fst")
                    # Skip speakers whose graph already exists from a previous run
                    if os.path.exists(hclg_path):
                        continue
                    utterances = (
                        session.query(Utterance.normalized_text)
                        .filter(Utterance.speaker_id == speaker_id)
                        .order_by(Utterance.kaldi_id)
                    )
                    mod_path = d.temp_directory.joinpath(f"g.{speaker_id}.fst")
                    # Pipeline: utterance text -> FAR -> ngram counts -> smoothed
                    # model -> shrunk G fst written to mod_path.
                    farcompile_proc = subprocess.Popen(
                        [
                            thirdparty_binary("farcompilestrings"),
                            "--fst_type=compact",
                            f"--unknown_symbol={d.oov_word}",
                            f"--symbols={word_symbols_path}",
                            "--keep_symbols",
                            "--generate_keys=16",
                        ],
                        stdin=subprocess.PIPE,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        env=os.environ,
                    )
                    count_proc = subprocess.Popen(
                        [thirdparty_binary("ngramcount"), f"--order={self.order}"],
                        stdin=farcompile_proc.stdout,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        env=os.environ,
                    )
                    # Use the configured smoothing method; previously this was
                    # hard-coded to kneser_ney, silently ignoring args.method.
                    make_proc = subprocess.Popen(
                        [thirdparty_binary("ngrammake"), f"--method={self.method}"],
                        stdin=count_proc.stdout,
                        stdout=subprocess.PIPE,
                        stderr=log_file,
                        env=os.environ,
                    )
                    shrink_proc = subprocess.Popen(
                        [
                            thirdparty_binary("ngramshrink"),
                            "--method=relative_entropy",
                            f"--target_number_of_ngrams={self.target_num_ngrams}",
                            "--shrink_opt=2",
                            "--theta=0.001",
                            "-",
                            mod_path,
                        ],
                        stdin=make_proc.stdout,
                        stderr=log_file,
                        env=os.environ,
                    )
                    for (text,) in utterances:
                        farcompile_proc.stdin.write(f"{text}\n".encode("utf8"))
                        farcompile_proc.stdin.flush()
                    farcompile_proc.stdin.close()
                    shrink_proc.wait()
                    # Compose the speaker G fst with the acoustic model's HCL to
                    # produce the speaker-specific decoding graph.
                    compiler = DecodeGraphCompiler(
                        self.model_path,
                        self.tree_path,
                        self.lexicon_compilers[dict_id],
                        **self.hclg_options,
                    )
                    compiler.g_fst = VectorFst.Read(str(mod_path))
                    compiler.export_hclg("", hclg_path)
                    self.callback(os.path.exists(hclg_path))