Source code for montreal_forced_aligner.g2p.phonetisaurus_trainer

from __future__ import annotations

import collections
import logging
import multiprocessing as mp
import os
import queue
import subprocess
import threading
import time
from pathlib import Path
from queue import Queue

import dataclassy
import numpy
import pynini
import pywrapfst
import sqlalchemy
from pynini.lib import rewrite
from tqdm.rich import tqdm

from montreal_forced_aligner import config
from montreal_forced_aligner.abc import MetaDict, TopLevelMfaWorker
from montreal_forced_aligner.data import WordType, WorkflowType
from montreal_forced_aligner.db import (
    Job,
    M2M2Job,
    M2MSymbol,
    Pronunciation,
    Word,
    Word2Job,
    bulk_update,
)
from montreal_forced_aligner.dictionary.multispeaker import MultispeakerDictionaryMixin
from montreal_forced_aligner.exceptions import PhonetisaurusSymbolError
from montreal_forced_aligner.g2p.generator import PyniniValidator
from montreal_forced_aligner.g2p.trainer import G2PTrainer
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.models import G2PModel
from montreal_forced_aligner.utils import thirdparty_binary

__all__ = ["PhonetisaurusTrainerMixin", "PhonetisaurusTrainer"]

logger = logging.getLogger("mfa")


@dataclassy.dataclass(slots=True)
class MaximizationArguments:
    """Arguments for the MaximizationWorker"""

    db_string: str
    far_path: Path
    penalize_em: bool
    batch_size: int


@dataclassy.dataclass(slots=True)
class ExpectationArguments:
    """Arguments for the ExpectationWorker"""

    db_string: str
    far_path: Path
    batch_size: int


@dataclassy.dataclass(slots=True)
class AlignmentExportArguments:
    """Arguments for the AlignmentExportWorker"""

    log_path: Path
    far_path: Path
    penalize: bool


@dataclassy.dataclass(slots=True)
class NgramCountArguments:
    """Arguments for the NgramCountWorker"""

    log_path: Path
    far_path: Path
    alignment_symbols_path: Path
    order: int


@dataclassy.dataclass(slots=True)
class AlignmentInitArguments:
    """Arguments for the alignment initialization worker"""

    db_string: str
    log_path: Path
    far_path: Path
    deletions: bool
    insertions: bool
    restrict: bool
    output_order: int
    input_order: int
    eps: str
    s1s2_sep: str
    seq_sep: str
    skip: str
    batch_size: int


class AlignmentInitWorker(mp.Process):
    """
    Multiprocessing worker that initializes alignment FSTs for a subset of the data

    Parameters
    ----------
    job_name: int
        Integer ID for the job
    return_queue: :class:`multiprocessing.Queue`
        Queue to return data
    stopped: :class:`~multiprocessing.Event`
        Stop check
    finished_adding: :class:`~multiprocessing.Event`
        Check for whether the job queue is done
    args: :class:`~montreal_forced_aligner.g2p.phonetisaurus_trainer.AlignmentInitArguments`
        Arguments for initialization
    """

    def __init__(
        self,
        job_name: int,
        return_queue: mp.Queue,
        stopped: mp.Event,
        finished_adding: mp.Event,
        args: AlignmentInitArguments,
    ):
        super().__init__()
        self.job_name = job_name
        self.return_queue = return_queue
        self.stopped = stopped
        self.finished = mp.Event()
        self.finished_adding = finished_adding
        self.deletions = args.deletions
        self.insertions = args.insertions
        self.restrict = args.restrict
        self.output_order = args.output_order
        self.input_order = args.input_order
        self.eps = args.eps
        self.s1s2_sep = args.s1s2_sep
        self.seq_sep = args.seq_sep
        self.skip = args.skip
        self.far_path = args.far_path
        self.sym_path = self.far_path.with_suffix(".syms")
        self.log_path = args.log_path
        self.db_string = args.db_string
        self.batch_size = args.batch_size

    def data_generator(self, session):
        """Generate (graphemes, phones) training pairs for this worker's job"""
        query = (
            session.query(Word.word, Pronunciation.pronunciation)
            .join(Pronunciation.word)
            .join(Word.job)
            .filter(Word2Job.training == True)  # noqa
            .filter(Word2Job.job_id == self.job_name)
        )
        for w, p in query:
            w = list(w)
            p = p.split()
            yield w, p

    def run(self) -> None:
        """Run the function"""

        engine = sqlalchemy.create_engine(
            self.db_string,
            poolclass=sqlalchemy.NullPool,
            pool_reset_on_return=None,
            isolation_level="AUTOCOMMIT",
            logging_name=f"{type(self).__name__}_engine",
        ).execution_options(logging_token=f"{type(self).__name__}_engine")
        try:
            symbol_table = pywrapfst.SymbolTable()
            symbol_table.add_symbol(self.eps)
            valid_output_ngrams = set()
            base_dir = os.path.dirname(self.far_path)
            with mfa_open(os.path.join(base_dir, "output_ngram.ngrams"), "r") as f:
                for line in f:
                    line = line.strip()
                    valid_output_ngrams.add(line)
            valid_input_ngrams = set()
            with mfa_open(os.path.join(base_dir, "input_ngram.ngrams"), "r") as f:
                for line in f:
                    line = line.strip()
                    valid_input_ngrams.add(line)
            count = 0
            data = {}
            with mfa_open(self.log_path, "w") as log_file, sqlalchemy.orm.Session(
                engine
            ) as session:
                far_writer = pywrapfst.FarWriter.create(self.far_path, arc_type="log")
                for current_index, (input, output) in enumerate(self.data_generator(session)):
                    if self.stopped.is_set():
                        continue
                    try:
                        key = f"{current_index:08x}"
                        fst = pynini.Fst(arc_type="log")
                        final_state = ((len(input) + 1) * (len(output) + 1)) - 1

                        for _ in range(final_state + 1):
                            fst.add_state()
                        for i in range(len(input) + 1):
                            for j in range(len(output) + 1):
                                istate = i * (len(output) + 1) + j
                                if self.deletions:
                                    for output_range in range(1, self.output_order + 1):
                                        if j + output_range <= len(output):
                                            subseq_output = output[j : j + output_range]
                                            output_string = self.seq_sep.join(subseq_output)
                                            if (
                                                output_range > 1
                                                and output_string not in valid_output_ngrams
                                            ):
                                                continue
                                            symbol = self.s1s2_sep.join([self.skip, output_string])
                                            ilabel = symbol_table.find(symbol)
                                            if ilabel == pywrapfst.NO_LABEL:
                                                ilabel = symbol_table.add_symbol(symbol)
                                            ostate = i * (len(output) + 1) + (j + output_range)
                                            fst.add_arc(
                                                istate,
                                                pywrapfst.Arc(
                                                    ilabel,
                                                    ilabel,
                                                    pywrapfst.Weight("log", 99.0),
                                                    ostate,
                                                ),
                                            )
                                if self.insertions:
                                    for input_range in range(1, self.input_order + 1):
                                        if i + input_range <= len(input):
                                            subseq_input = input[i : i + input_range]
                                            input_string = self.seq_sep.join(subseq_input)
                                            if (
                                                input_range > 1
                                                and input_string not in valid_input_ngrams
                                            ):
                                                continue
                                            symbol = self.s1s2_sep.join([input_string, self.skip])
                                            ilabel = symbol_table.find(symbol)
                                            if ilabel == pywrapfst.NO_LABEL:
                                                ilabel = symbol_table.add_symbol(symbol)
                                            ostate = (i + input_range) * (len(output) + 1) + j
                                            fst.add_arc(
                                                istate,
                                                pywrapfst.Arc(
                                                    ilabel,
                                                    ilabel,
                                                    pywrapfst.Weight("log", 99.0),
                                                    ostate,
                                                ),
                                            )

                                for input_range in range(1, self.input_order + 1):
                                    for output_range in range(1, self.output_order + 1):
                                        if i + input_range <= len(
                                            input
                                        ) and j + output_range <= len(output):
                                            if (
                                                self.restrict
                                                and input_range > 1
                                                and output_range > 1
                                            ):
                                                continue
                                            subseq_output = output[j : j + output_range]
                                            output_string = self.seq_sep.join(subseq_output)
                                            if (
                                                output_range > 1
                                                and output_string not in valid_output_ngrams
                                            ):
                                                continue
                                            subseq_input = input[i : i + input_range]
                                            input_string = self.seq_sep.join(subseq_input)
                                            if (
                                                input_range > 1
                                                and input_string not in valid_input_ngrams
                                            ):
                                                continue
                                            symbol = self.s1s2_sep.join(
                                                [input_string, output_string]
                                            )
                                            ilabel = symbol_table.find(symbol)
                                            if ilabel == pywrapfst.NO_LABEL:
                                                ilabel = symbol_table.add_symbol(symbol)
                                            ostate = (i + input_range) * (len(output) + 1) + (
                                                j + output_range
                                            )
                                            fst.add_arc(
                                                istate,
                                                pywrapfst.Arc(
                                                    ilabel,
                                                    ilabel,
                                                    pywrapfst.Weight(
                                                        "log", float(input_range * output_range)
                                                    ),
                                                    ostate,
                                                ),
                                            )
                        fst.set_start(0)
                        fst.set_final(final_state, pywrapfst.Weight.one(fst.weight_type()))
                        fst = pynini.connect(fst)
                        for state in fst.states():
                            for arc in fst.arcs(state):
                                sym = symbol_table.find(arc.ilabel)
                                if sym not in data:
                                    data[sym] = arc.weight
                                else:
                                    data[sym] = pywrapfst.plus(data[sym], arc.weight)
                        if count >= self.batch_size:
                            data = {k: float(v) for k, v in data.items()}
                            self.return_queue.put((self.job_name, data, count))
                            data = {}
                            count = 0
                        log_file.flush()
                        far_writer[key] = fst
                        del fst
                        count += 1
                    except Exception as e:  # noqa
                        self.stopped.set()
                        self.return_queue.put(e)
            if data:
                data = {k: float(v) for k, v in data.items()}
                self.return_queue.put((self.job_name, data, count))
            symbol_table.write_text(self.far_path.with_suffix(".syms"))
            return
        except Exception as e:
            self.stopped.set()
            self.return_queue.put(e)
        finally:
            self.finished.set()
            del far_writer
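

# Illustrative sketch (not part of the original module): ``AlignmentInitWorker``
# lays out a ((len(input) + 1) x (len(output) + 1)) alignment lattice with the
# state indexing ``state = i * (len(output) + 1) + j``. The hypothetical helper
# below rebuilds that grid for an equal-length toy pair with one-to-one arcs
# only, to show the indexing in isolation; it uses only ``pynini``/``pywrapfst``
# from this module's imports.
def _example_alignment_lattice() -> pynini.Fst:
    graphemes = list("go")
    phones = "G OW".split()
    num_cols = len(phones) + 1
    fst = pynini.Fst(arc_type="log")
    for _ in range((len(graphemes) + 1) * num_cols):
        fst.add_state()
    for i in range(len(graphemes)):
        for j in range(len(phones)):
            # Arc from grid cell (i, j) to (i + 1, j + 1): one grapheme aligned
            # to one phone, with an arbitrary unit cost and an epsilon label.
            fst.add_arc(
                i * num_cols + j,
                pywrapfst.Arc(0, 0, pywrapfst.Weight("log", 1.0), (i + 1) * num_cols + j + 1),
            )
    fst.set_start(0)
    fst.set_final((len(graphemes) + 1) * num_cols - 1, pywrapfst.Weight.one(fst.weight_type()))
    return pynini.connect(fst)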


class ExpectationWorker(mp.Process):
    """
    Multiprocessing worker that runs the expectation step of training for a subset of the data

    Parameters
    ----------
    job_name: int
        Integer ID for the job
    return_queue: :class:`multiprocessing.Queue`
        Queue to return data
    stopped: :class:`~multiprocessing.Event`
        Stop check
    args: :class:`~montreal_forced_aligner.g2p.phonetisaurus_trainer.ExpectationArguments`
        Arguments for the function
    """

    def __init__(
        self,
        job_name: int,
        return_queue: mp.Queue,
        stopped: mp.Event,
        args: ExpectationArguments,
    ):
        super().__init__()
        self.job_name = job_name
        self.db_string = args.db_string
        self.far_path = args.far_path
        self.batch_size = args.batch_size
        self.return_queue = return_queue
        self.stopped = stopped
        self.finished = mp.Event()

    def run(self) -> None:
        """Run the function"""
        engine = sqlalchemy.create_engine(
            self.db_string,
            poolclass=sqlalchemy.NullPool,
            pool_reset_on_return=None,
            isolation_level="AUTOCOMMIT",
            logging_name=f"{type(self).__name__}_engine",
        ).execution_options(logging_token=f"{type(self).__name__}_engine")
        far_reader = pywrapfst.FarReader.open(self.far_path)
        symbol_table = pywrapfst.SymbolTable.read_text(self.far_path.with_suffix(".syms"))
        symbol_mapper = {}
        data = {}
        count = 0
        with sqlalchemy.orm.Session(engine) as session:
            query = (
                session.query(M2MSymbol.symbol, M2MSymbol.id)
                .join(M2MSymbol.jobs)
                .filter(M2M2Job.job_id == self.job_name)
            )
            for symbol, sym_id in query:
                symbol_mapper[symbol_table.find(symbol)] = sym_id
        while not far_reader.done():
            if self.stopped.is_set():
                break
            fst = far_reader.get_fst()
            zero = pywrapfst.Weight.zero("log")
            try:
                fst = pynini.Fst.read_from_string(fst.write_to_string())
                alpha = pynini.shortestdistance(fst)
                beta = pynini.shortestdistance(fst, reverse=True)
                for state_id in fst.states():
                    for arc in fst.arcs(state_id):
                        gamma = pywrapfst.divide(
                            pywrapfst.times(
                                pywrapfst.times(alpha[state_id], arc.weight), beta[arc.nextstate]
                            ),
                            beta[0],
                        )
                        if float(gamma) != numpy.inf:
                            sym_id = symbol_mapper[arc.ilabel]
                            if sym_id not in data:
                                data[sym_id] = zero
                            data[sym_id] = pywrapfst.plus(data[sym_id], gamma)
                if count >= self.batch_size:
                    data = {k: float(v) for k, v in data.items()}
                    self.return_queue.put((data, count))
                    data = {}
                    count = 0
                next(far_reader)
                del alpha
                del beta
                del fst
                count += 1
            except Exception as e:  # noqa
                self.stopped.set()
                self.return_queue.put(e)
                raise
        if data:
            data = {k: float(v) for k, v in data.items()}
            self.return_queue.put((data, count))
        self.finished.set()
        del far_reader
        return
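

# Illustrative sketch (not part of the original module): the expectation step
# above scores each arc with the posterior
# ``gamma = (alpha[src] (x) w (x) beta[dst]) / beta[start]``
# computed from forward/backward shortest distances in the log semiring. The
# hypothetical helper below reproduces that computation on a toy two-arc FST;
# the two probabilities recovered with exp(-gamma) sum to 1.
def _example_arc_posteriors() -> dict:
    fst = pynini.Fst(arc_type="log")
    s0 = fst.add_state()
    s1 = fst.add_state()
    fst.set_start(s0)
    fst.set_final(s1, pywrapfst.Weight.one("log"))
    # Two parallel arcs with different -log costs.
    fst.add_arc(s0, pywrapfst.Arc(1, 1, pywrapfst.Weight("log", 0.5), s1))
    fst.add_arc(s0, pywrapfst.Arc(2, 2, pywrapfst.Weight("log", 2.0), s1))
    alpha = pynini.shortestdistance(fst)
    beta = pynini.shortestdistance(fst, reverse=True)
    posteriors = {}
    for state in fst.states():
        for arc in fst.arcs(state):
            gamma = pywrapfst.divide(
                pywrapfst.times(pywrapfst.times(alpha[state], arc.weight), beta[arc.nextstate]),
                beta[fst.start()],
            )
            posteriors[arc.ilabel] = float(gamma)  # -log posterior probability
    return posteriors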


class MaximizationWorker(mp.Process):
    """
    Multiprocessing worker that runs the maximization step of training for a subset of the data

    Parameters
    ----------
    job_name: int
        Integer ID for the job
    return_queue: :class:`multiprocessing.Queue`
        Queue to return data
    stopped: :class:`~multiprocessing.Event`
        Stop check
    args: :class:`~montreal_forced_aligner.g2p.phonetisaurus_trainer.MaximizationArguments`
        Arguments for maximization
    """

    def __init__(
        self,
        job_name: int,
        return_queue: mp.Queue,
        stopped: mp.Event,
        args: MaximizationArguments,
    ):
        super().__init__()
        self.job_name = job_name
        self.return_queue = return_queue
        self.stopped = stopped
        self.finished = mp.Event()
        self.db_string = args.db_string
        self.penalize_em = args.penalize_em
        self.far_path = args.far_path
        self.batch_size = args.batch_size

    def run(self) -> None:
        """Run the function"""
        symbol_table = pywrapfst.SymbolTable.read_text(self.far_path.with_suffix(".syms"))
        count = 0
        engine = sqlalchemy.create_engine(
            self.db_string,
            poolclass=sqlalchemy.NullPool,
            pool_reset_on_return=None,
            isolation_level="AUTOCOMMIT",
            logging_name=f"{type(self).__name__}_engine",
        ).execution_options(logging_token=f"{type(self).__name__}_engine")
        try:
            alignment_model = {}
            with sqlalchemy.orm.Session(engine) as session:
                query = (
                    session.query(M2MSymbol)
                    .join(M2MSymbol.jobs)
                    .filter(M2M2Job.job_id == self.job_name)
                )
                for m2m in query:
                    weight = pywrapfst.Weight("log", m2m.weight)
                    if self.penalize_em:
                        if m2m.grapheme_order > 1 or m2m.phone_order > 1:
                            weight = pywrapfst.Weight("log", float(weight) * m2m.total_order)
                        if weight == pywrapfst.Weight.zero("log") or float(weight) == numpy.inf:
                            weight = pywrapfst.Weight("log", 99)
                    alignment_model[symbol_table.find(m2m.symbol)] = weight
            far_reader = pywrapfst.FarReader.open(self.far_path)
            far_writer = pywrapfst.FarWriter.create(
                self.far_path.with_suffix(self.far_path.suffix + ".temp"), arc_type="log"
            )
            while not far_reader.done():
                if self.stopped.is_set():
                    break
                key = far_reader.get_key()
                fst = far_reader.get_fst()
                for state_id in fst.states():
                    maiter = fst.mutable_arcs(state_id)
                    while not maiter.done():
                        arc = maiter.value()
                        arc.weight = alignment_model[arc.ilabel]
                        arc = pywrapfst.Arc(arc.ilabel, arc.olabel, arc.weight, arc.nextstate)
                        maiter.set_value(arc)
                        next(maiter)
                    del maiter
                far_writer[key] = fst
                next(far_reader)
                if count >= self.batch_size:
                    self.return_queue.put(count)
                    count = 0
                del fst
                count += 1
            del far_reader
            del far_writer
            os.remove(self.far_path)
            os.rename(self.far_path.with_suffix(self.far_path.suffix + ".temp"), self.far_path)
        except Exception as e:
            self.stopped.set()
            self.return_queue.put(e)
            raise
        finally:
            if count >= 1:
                self.return_queue.put(count)
            self.finished.set()
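

# Illustrative sketch (not part of the original module): the maximization step
# overwrites every arc weight in place through a mutable arc iterator, exactly
# as ``MaximizationWorker.run`` does with its ``alignment_model`` lookup. The
# ``model`` argument here is a hypothetical mapping from input label to
# :class:`~pywrapfst.Weight`.
def _example_reweight_arcs(fst: pynini.Fst, model: dict) -> pynini.Fst:
    for state in fst.states():
        maiter = fst.mutable_arcs(state)
        while not maiter.done():
            arc = maiter.value()
            # Replace the arc weight, keeping labels and destination intact;
            # fall back to the existing weight for labels missing from the model.
            weight = model.get(arc.ilabel, arc.weight)
            maiter.set_value(pywrapfst.Arc(arc.ilabel, arc.olabel, weight, arc.nextstate))
            next(maiter)
    return fst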


class AlignmentExporter(mp.Process):
    """
    Multiprocessing worker to export one-best alignment strings from aligned FST archives

    Parameters
    ----------
    return_queue: :class:`multiprocessing.Queue`
        Queue to return data
    stopped: :class:`~multiprocessing.Event`
        Stop check
    args: :class:`~montreal_forced_aligner.g2p.phonetisaurus_trainer.AlignmentExportArguments`
        Arguments for alignment export
    """

    def __init__(self, return_queue: mp.Queue, stopped: mp.Event, args: AlignmentExportArguments):
        super().__init__()
        self.return_queue = return_queue
        self.stopped = stopped
        self.finished = mp.Event()
        self.penalize = args.penalize
        self.far_path = args.far_path
        self.log_path = args.log_path

    def run(self) -> None:
        """Run the function"""
        symbol_table = pywrapfst.SymbolTable.read_text(self.far_path.with_suffix(".syms"))
        with mfa_open(self.log_path, "w") as log_file:
            far_reader = pywrapfst.FarReader.open(self.far_path)
            one_best_path = self.far_path.with_suffix(".strings")
            no_alignment_count = 0
            total = 0
            with mfa_open(one_best_path, "w") as f:
                while not far_reader.done():
                    fst = far_reader.get_fst()
                    total += 1
                    if fst.num_states() == 0:
                        next(far_reader)
                        no_alignment_count += 1
                        self.return_queue.put(1)
                        continue
                    tfst = pynini.arcmap(
                        pynini.Fst.read_from_string(fst.write_to_string()), map_type="to_std"
                    )
                    if self.penalize:
                        for state in tfst.states():
                            maiter = tfst.mutable_arcs(state)
                            while not maiter.done():
                                arc = maiter.value()
                                sym = symbol_table.find(arc.ilabel)
                                # Assumes a ``penalties`` mapping from symbol to an
                                # order record with ``lhs``/``rhs``/``max`` fields;
                                # it is not initialized in this worker's __init__.
                                ld = self.penalties[sym]
                                if ld.lhs > 1 and ld.rhs > 1:
                                    arc.weight = pywrapfst.Weight(tfst.weight_type(), 999)
                                else:
                                    arc.weight = pywrapfst.Weight(
                                        tfst.weight_type(), float(arc.weight) * ld.max
                                    )
                                maiter.set_value(arc)
                                next(maiter)
                            del maiter
                    pfst = rewrite.lattice_to_dfa(tfst, True, 8).project("output").rmepsilon()

                    if pfst.start() != pywrapfst.NO_STATE_ID:
                        path = pynini.shortestpath(pfst)
                    else:
                        pfst = rewrite.lattice_to_dfa(tfst, False, 8).project("output").rmepsilon()
                        path = pynini.shortestpath(pfst)
                    string = path.string(symbol_table)
                    f.write(f"{string}\n")
                    log_file.flush()
                    next(far_reader)
                    self.return_queue.put(1)
                    del fst
                    del pfst
                    del path
                    del tfst
            log_file.write(
                f"Done {total - no_alignment_count}, no alignment for {no_alignment_count}"
            )
            log_file.flush()
            self.finished.set()
            del far_reader
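

# Illustrative sketch (not part of the original module): the exporter reduces
# each aligned lattice to a single best alignment string with
# ``pynini.shortestpath`` and ``Fst.string``. A hypothetical minimal example
# with a one-arc FST over a two-entry symbol table:
def _example_one_best_string() -> str:
    symbol_table = pywrapfst.SymbolTable()
    symbol_table.add_symbol("<eps>")
    sym = symbol_table.add_symbol("g;G")
    fst = pynini.Fst()
    s0 = fst.add_state()
    s1 = fst.add_state()
    fst.set_start(s0)
    fst.set_final(s1)
    fst.add_arc(s0, pywrapfst.Arc(sym, sym, pywrapfst.Weight.one(fst.weight_type()), s1))
    path = pynini.shortestpath(fst)
    return path.string(symbol_table)  # -> "g;G"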


class NgramCountWorker(threading.Thread):
    """
    Threading worker to generate ngram counts for aligned FST archives

    Parameters
    ----------
    return_queue: :class:`~queue.Queue`
        Queue to return data
    stopped: :class:`~threading.Event`
        Stop check
    args: :class:`~montreal_forced_aligner.g2p.phonetisaurus_trainer.NgramCountArguments`
        Arguments for ngram counting
    """

    def __init__(self, return_queue: Queue, stopped: threading.Event, args: NgramCountArguments):
        super().__init__()
        self.return_queue = return_queue
        self.stopped = stopped
        self.finished = threading.Event()
        self.order = args.order
        self.far_path = args.far_path
        self.log_path = args.log_path
        self.alignment_symbols_path = args.alignment_symbols_path

    def run(self) -> None:
        """Run the function"""
        with mfa_open(self.log_path, "w") as log_file:
            one_best_path = self.far_path.with_suffix(".strings")
            ngram_count_path = self.far_path.with_suffix(".cnts")
            farcompile_proc = subprocess.Popen(
                [
                    thirdparty_binary("farcompilestrings"),
                    "--fst_type=compact",
                    "--token_type=symbol",
                    f"--symbols={self.alignment_symbols_path}",
                    one_best_path,
                ],
                stderr=log_file,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            ngramcount_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramcount"),
                    "--require_symbols=false",
                    "--round_to_int",
                    f"--order={self.order}",
                    "-",
                    ngram_count_path,
                ],
                stderr=log_file,
                stdin=farcompile_proc.stdout,
                env=os.environ,
            )
            ngramcount_proc.communicate()
            self.finished.set()
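

# Illustrative sketch (not part of the original module): the worker above is
# the in-process equivalent of the shell pipeline
# ``farcompilestrings ... | ngramcount ...``. The hypothetical standalone
# helper below wires the same two OpenGrm binaries together; ``strings_path``,
# ``symbols_path`` and ``counts_path`` are assumed inputs/outputs.
def _example_count_ngrams(
    strings_path: Path, symbols_path: Path, counts_path: Path, order: int = 3
) -> None:
    farcompile_proc = subprocess.Popen(
        [
            thirdparty_binary("farcompilestrings"),
            "--fst_type=compact",
            "--token_type=symbol",
            f"--symbols={symbols_path}",
            strings_path,
        ],
        stdout=subprocess.PIPE,
        env=os.environ,
    )
    # ngramcount reads the FST archive from stdin and writes counts to disk.
    subprocess.check_call(
        [
            thirdparty_binary("ngramcount"),
            "--require_symbols=false",
            "--round_to_int",
            f"--order={order}",
            "-",
            counts_path,
        ],
        stdin=farcompile_proc.stdout,
        env=os.environ,
    )
    farcompile_proc.wait()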


class PhonetisaurusTrainerMixin:
    """
    Mixin class for training Phonetisaurus-style models

    Parameters
    ----------
    order: int
        Order of the ngram model, defaults to 8
    batch_size: int
        Batch size for training, defaults to 1000
    num_iterations: int
        Maximum number of iterations to use in Baum-Welch training, defaults to 10
    smoothing_method: str
        Smoothing method for the ngram model, defaults to "kneser_ney"
    pruning_method: str
        Pruning method for pruning the ngram model, defaults to "relative_entropy"
    model_size: int
        Target number of ngrams for pruning, defaults to 1000000
    initial_prune_threshold: float
        Pruning threshold for calculating the multiple phone/grapheme strings that are to be allowed, defaults to 0.0001
    insertions: bool
        Flag for whether to allow for insertions, default True
    deletions: bool
        Flag for whether to allow for deletions, default True
    restrict_m2m: bool
        Flag for whether to restrict possible alignments to one-to-many and disable many-to-many alignments, default False
    penalize_em: bool
        Flag for whether many-to-many and one-to-many mappings are penalized over one-to-one mappings during training, default False
    penalize: bool
        Flag for whether many-to-many and one-to-many mappings are penalized over one-to-one mappings during export, default False
    sequence_separator: str
        Character to use for concatenating and aligning multiple phones or graphemes, defaults to "|"
    skip: str
        Character to use to represent deletions or insertions, defaults to "_"
    alignment_separator: str
        Character to use for concatenating grapheme strings and phone strings, defaults to ";"
    grapheme_order: int
        Maximum number of graphemes to map to single phones, defaults to 2
    phone_order: int
        Maximum number of phones to map to single graphemes, defaults to 2
    em_threshold: float
        Threshold of minimum change for early stopping of EM training, defaults to 1e-5
    """

    alignment_init_function = AlignmentInitWorker

    def __init__(
        self,
        order: int = 8,
        batch_size: int = 1000,
        num_iterations: int = 10,
        smoothing_method: str = "kneser_ney",
        pruning_method: str = "relative_entropy",
        model_size: int = 1000000,
        initial_prune_threshold: float = 0.0001,
        insertions: bool = True,
        deletions: bool = True,
        restrict_m2m: bool = False,
        penalize_em: bool = False,
        penalize: bool = False,
        sequence_separator: str = "|",
        skip: str = "_",
        alignment_separator: str = ";",
        grapheme_order: int = 2,
        phone_order: int = 2,
        em_threshold: float = 1e-5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not hasattr(self, "_data_source"):
            self._data_source = None
        self.order = order
        self.batch_size = batch_size
        self.num_iterations = num_iterations
        self.smoothing_method = smoothing_method
        self.pruning_method = pruning_method
        self.model_size = model_size
        self.initial_prune_threshold = initial_prune_threshold
        self.insertions = insertions
        self.deletions = deletions
        self.grapheme_order = grapheme_order
        self.phone_order = phone_order
        self.input_order = self.grapheme_order
        self.output_order = self.phone_order
        self.sequence_separator = sequence_separator
        self.alignment_separator = alignment_separator
        self.skip = skip
        self.eps = "<eps>"
        self.restrict_m2m = restrict_m2m
        self.penalize_em = penalize_em
        self.penalize = penalize
        self.em_threshold = em_threshold
        self.g2p_num_training_pronunciations = 0
        self.symbol_table = pywrapfst.SymbolTable()
        self.symbol_table.add_symbol(self.eps)
        self.total = pywrapfst.Weight.zero("log")
        self.prev_total = pywrapfst.Weight.zero("log")

    @property
    def architecture(self) -> str:
        """Phonetisaurus"""
        return "phonetisaurus"

    def initialize_alignments(self) -> None:
        """
        Initialize alignment FSTs for training
        """

        logger.info("Creating alignment FSTs...")
        return_queue = mp.Queue()
        stopped = mp.Event()
        finished_adding = mp.Event()
        procs = []
        for i in range(1, config.NUM_JOBS + 1):
            args = AlignmentInitArguments(
                self.db_string,
                self.working_log_directory.joinpath(f"alignment_init.{i}.log"),
                self.working_directory.joinpath(f"{i}.far"),
                self.deletions,
                self.insertions,
                self.restrict_m2m,
                self.phone_order,
                self.grapheme_order,
                self.eps,
                self.alignment_separator,
                self.sequence_separator,
                self.skip,
                self.batch_size,
            )
            procs.append(
                self.alignment_init_function(
                    i,
                    return_queue,
                    stopped,
                    finished_adding,
                    args,
                )
            )
            procs[-1].start()

        finished_adding.set()
        error_list = []
        symbols = {}
        job_symbols = {}
        symbol_id = 1
        with tqdm(
            total=self.g2p_num_training_pronunciations, disable=config.QUIET
        ) as pbar, self.session() as session:
            while True:
                try:
                    result = return_queue.get(timeout=2)
                    if isinstance(result, Exception):
                        error_list.append(result)
                        continue
                    if stopped.is_set():
                        continue
                except queue.Empty:
                    for p in procs:
                        if not p.finished.is_set():
                            break
                    else:
                        break
                    continue
                job_name, weights, count = result
                for symbol, weight in weights.items():
                    weight = pywrapfst.Weight("log", weight)
                    if symbol not in symbols:
                        left_side, right_side = symbol.split(self.alignment_separator)
                        if left_side == self.skip:
                            left_side_order = 0
                        else:
                            left_side_order = 1 + left_side.count(self.sequence_separator)
                        if right_side == self.skip:
                            right_side_order = 0
                        else:
                            right_side_order = 1 + right_side.count(self.sequence_separator)
                        max_order = max(left_side_order, right_side_order)
                        total_order = left_side_order + right_side_order
                        symbols[symbol] = {
                            "symbol": symbol,
                            "id": symbol_id,
                            "total_order": total_order,
                            "max_order": max_order,
                            "grapheme_order": left_side_order,
                            "phone_order": right_side_order,
                            "weight": weight,
                        }
                        symbol_id += 1
                    else:
                        symbols[symbol]["weight"] = pywrapfst.plus(
                            symbols[symbol]["weight"], weight
                        )
                    self.total = pywrapfst.plus(self.total, weight)
                    if job_name not in job_symbols:
                        job_symbols[job_name] = set()
                    job_symbols[job_name].add(symbols[symbol]["id"])
                pbar.update(count)
            for p in procs:
                p.join()
            if error_list:
                for v in error_list:
                    raise v
            logger.debug(f"Total of {len(symbols)} symbols, initial total: {self.total}")
            symbols = [x for x in symbols.values()]
            for data in symbols:
                data["weight"] = float(data["weight"])
            session.bulk_insert_mappings(
                M2MSymbol, symbols, return_defaults=False, render_nulls=True
            )
            session.flush()
            del symbols
            mappings = []
            for j, sym_ids in job_symbols.items():
                mappings.extend({"m2m_id": x, "job_id": j} for x in sym_ids)
            session.bulk_insert_mappings(
                M2M2Job, mappings, return_defaults=False, render_nulls=True
            )
            session.commit()

    def maximization(self, last_iteration=False) -> float:
        """
        Run the maximization step for training

        Returns
        -------
        float
            Current iteration's score
        """
        logger.info("Performing maximization step...")
        change = abs(float(self.total) - float(self.prev_total))
        logger.debug(f"Previous total: {float(self.prev_total)}")
        logger.debug(f"Current total: {float(self.total)}")
        logger.debug(f"Change: {change}")

        self.prev_total = self.total
        with self.session() as session:
            session.query(M2MSymbol).update(
                {"weight": M2MSymbol.weight - float(self.total)}, synchronize_session=False
            )
            session.commit()
        return_queue = mp.Queue()
        stopped = mp.Event()
        procs = []
        for i in range(1, config.NUM_JOBS + 1):
            args = MaximizationArguments(
                self.db_string,
                self.working_directory.joinpath(f"{i}.far"),
                self.penalize_em,
                self.batch_size,
            )
            procs.append(MaximizationWorker(i, return_queue, stopped, args))
            procs[-1].start()

        error_list = []
        with tqdm(total=self.g2p_num_training_pronunciations, disable=config.QUIET) as pbar:
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_list.append(result)
                        continue
                    if stopped.is_set():
                        continue
                except queue.Empty:
                    for p in procs:
                        if not p.finished.is_set():
                            break
                    else:
                        break
                    continue
                pbar.update(result)
        for p in procs:
            p.join()
        if error_list:
            for v in error_list:
                raise v
        if not last_iteration and change >= self.em_threshold:  # we're still converging
            self.total = pywrapfst.Weight.zero("log")
            with self.session() as session:
                session.query(M2MSymbol).update({"weight": 0.0})
                session.commit()
        logger.info(f"Maximization done! Change from last iteration was {change:.3f}")
        return change

    def expectation(self) -> None:
        """
        Run the expectation step for training
        """
        logger.info("Performing expectation step...")
        return_queue = mp.Queue()
        stopped = mp.Event()
        error_list = []
        procs = []
        for i in range(1, config.NUM_JOBS + 1):
            args = ExpectationArguments(
                self.db_string,
                self.working_directory.joinpath(f"{i}.far"),
                self.batch_size,
            )
            procs.append(ExpectationWorker(i, return_queue, stopped, args))
            procs[-1].start()

        mappings = {}
        zero = pywrapfst.Weight.zero("log")
        with tqdm(total=self.g2p_num_training_pronunciations, disable=config.QUIET) as pbar:
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_list.append(result)
                        continue
                    if stopped.is_set():
                        continue
                except queue.Empty:
                    for p in procs:
                        if not p.finished.is_set():
                            break
                    else:
                        break
                    continue
                result, count = result
                for sym_id, gamma in result.items():
                    gamma = pywrapfst.Weight("log", gamma)
                    if sym_id not in mappings:
                        mappings[sym_id] = zero
                    mappings[sym_id] = pywrapfst.plus(mappings[sym_id], gamma)
                    self.total = pywrapfst.plus(self.total, gamma)
                pbar.update(count)
        for p in procs:
            p.join()
        if error_list:
            for v in error_list:
                raise v
        with self.session() as session:
            bulk_update(
                session, M2MSymbol, [{"id": k, "weight": float(v)} for k, v in mappings.items()]
            )
            session.commit()
        logger.info("Expectation done!")

    def train_ngram_model(self) -> None:
        """
        Train an ngram model on the aligned FSTs
        """
        if os.path.exists(self.fst_path):
            logger.info("Ngram model already exists.")
            return
        logger.info("Generating ngram counts...")
        return_queue = Queue()
        stopped = threading.Event()
        error_list = []
        procs = []
        count_paths = []
        for i in range(1, config.NUM_JOBS + 1):
            args = NgramCountArguments(
                self.working_log_directory.joinpath(f"ngram_count.{i}.log"),
                self.working_directory.joinpath(f"{i}.far"),
                self.alignment_symbols_path,
                self.order,
            )
            procs.append(NgramCountWorker(return_queue, stopped, args))
            count_paths.append(args.far_path.with_suffix(".cnts"))
            procs[-1].start()

        with tqdm(total=self.g2p_num_training_pronunciations, disable=config.QUIET) as pbar:
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_list.append(result)
                        continue
                    if stopped.is_set():
                        continue
                    return_queue.task_done()
                except queue.Empty:
                    for p in procs:
                        if not p.finished.is_set():
                            break
                    else:
                        break
                    continue
                pbar.update(1)
        for p in procs:
            p.join()
        if error_list:
            for v in error_list:
                raise v
        logger.info("Done counting ngrams!")
        logger.info("Training ngram model...")
        with mfa_open(self.working_log_directory.joinpath("model.log"), "w") as logf:
            if len(count_paths) > 1:
                ngrammerge_proc = subprocess.Popen(
                    [
                        thirdparty_binary("ngrammerge"),
                        f'--ofile={self.ngram_path.with_suffix(".cnts")}',
                        *count_paths,
                    ],
                    stderr=logf,
                    # stdout=subprocess.PIPE,
                    env=os.environ,
                )
                ngrammerge_proc.communicate()
            else:
                os.rename(count_paths[0], self.ngram_path.with_suffix(".cnts"))
            ngrammake_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngrammake"),
                    f"--method={self.smoothing_method}",
                    self.ngram_path.with_suffix(".cnts"),
                ],
                stderr=logf,
                stdout=subprocess.PIPE,
                env=os.environ,
            )
            ngramshrink_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramshrink"),
                    f"--method={self.pruning_method}",
                    f"--target_number_of_ngrams={self.model_size}",
                    "-",
                    self.ngram_path,
                ],
                stdin=ngrammake_proc.stdout,
                stderr=logf,
                env=os.environ,
            )
            ngramshrink_proc.communicate()

        ngram_fst = pynini.Fst.read(self.ngram_path)
        grapheme_symbols = pywrapfst.SymbolTable()
        grapheme_symbols.add_symbol(self.eps)
        grapheme_symbols.add_symbol(self.sequence_separator)
        grapheme_symbols.add_symbol(self.skip)
        phone_symbols = pywrapfst.SymbolTable()
        phone_symbols.add_symbol(self.eps)
        phone_symbols.add_symbol(self.sequence_separator)
        phone_symbols.add_symbol(self.skip)
        single_phone_symbols = pywrapfst.SymbolTable()
        single_phone_symbols.add_symbol(self.eps)
        single_phone_fst = pynini.Fst()
        start_state = single_phone_fst.add_state()
        single_phone_fst.set_start(start_state)
        one = pywrapfst.Weight.one(single_phone_fst.weight_type())
        single_phone_fst.set_final(start_state, one)
        current_ind = 1
        for state in ngram_fst.states():
            maiter = ngram_fst.mutable_arcs(state)
            while not maiter.done():
                arc = maiter.value()
                symbol = self.symbol_table.find(arc.ilabel)
                try:
                    grapheme, phone = symbol.split(self.alignment_separator)
                    if grapheme == self.skip:
                        g_symbol = grapheme_symbols.find(self.eps)
                    else:
                        g_symbol = grapheme_symbols.find(grapheme)
                        if g_symbol == pywrapfst.NO_SYMBOL:
                            g_symbol = grapheme_symbols.add_symbol(grapheme)
                    if phone == self.skip:
                        p_symbol = phone_symbols.find(self.eps)
                    else:
                        p_symbol = phone_symbols.find(phone)
                        if p_symbol == pywrapfst.NO_SYMBOL:
                            p_symbol = phone_symbols.add_symbol(phone)
                        # Expand a multi-phone symbol into a chain of
                        # single-phone arcs in the helper FST.
                        singles = phone.split(self.sequence_separator)
                        for i, s in enumerate(singles):
                            s_symbol = single_phone_symbols.find(s)
                            if s_symbol == pywrapfst.NO_SYMBOL:
                                s_symbol = single_phone_symbols.add_symbol(s)
                            if i == 0:
                                single_start = start_state
                            else:
                                single_start = current_ind
                            if i < len(singles) - 1:
                                current_ind = single_phone_fst.add_state()
                                end_state = current_ind
                            else:
                                end_state = start_state
                            single_phone_fst.add_arc(
                                single_start,
                                pywrapfst.Arc(
                                    p_symbol if i == 0 else 0, s_symbol, one, end_state
                                ),
                            )
                    arc = pywrapfst.Arc(g_symbol, p_symbol, arc.weight, arc.nextstate)
                    maiter.set_value(arc)
                except ValueError:
                    if symbol in {"<eps>", "<unk>", "<epsilon>"}:
                        arc = pywrapfst.Arc(0, 0, arc.weight, arc.nextstate)
                        maiter.set_value(arc)
                    else:
                        raise
                finally:
                    next(maiter)
        for i in range(grapheme_symbols.num_symbols()):
            sym = grapheme_symbols.find(i)
            if sym in {self.eps, self.sequence_separator, self.skip}:
                continue
            parts = sym.split(self.sequence_separator)
            if len(parts) > 1:
                for s in parts:
                    if grapheme_symbols.find(s) == pywrapfst.NO_SYMBOL:
                        k = grapheme_symbols.add_symbol(s)
                        ngram_fst.add_arc(1, pywrapfst.Arc(k, 2, 99, 1))
        for i in range(phone_symbols.num_symbols()):
            sym = phone_symbols.find(i)
            if sym in {self.eps, self.sequence_separator, self.skip}:
                continue
            parts = sym.split(self.sequence_separator)
            if len(parts) > 1:
                for s in parts:
                    if phone_symbols.find(s) == pywrapfst.NO_SYMBOL:
                        k = phone_symbols.add_symbol(s)
                        ngram_fst.add_arc(1, pywrapfst.Arc(2, k, 99, 1))
        single_phone_fst.set_input_symbols(phone_symbols)
        single_phone_fst.set_output_symbols(single_phone_symbols)
        ngram_fst.set_input_symbols(grapheme_symbols)
        ngram_fst.set_output_symbols(phone_symbols)
        single_ngram_fst = pynini.compose(ngram_fst, single_phone_fst)
        single_ngram_fst.set_input_symbols(grapheme_symbols)
        single_ngram_fst.set_output_symbols(single_phone_symbols)
        grapheme_symbols.write_text(self.grapheme_symbols_path)
        single_phone_symbols.write_text(self.phone_symbols_path)
        single_ngram_fst.write(self.fst_path)

    def train_alignments(self) -> None:
        """
        Run Expectation-Maximization (EM) training on alignment FSTs to generate well-aligned FSTs for ngram modeling
        """
        if os.path.exists(self.alignment_model_path):
            logger.info("Using existing alignments.")
            self.symbol_table = pywrapfst.SymbolTable.read_text(self.alignment_symbols_path)
            return
        self.initialize_alignments()

        self.maximization()
        logger.info("Training alignments...")
        for i in range(self.num_iterations):
            logger.info(f"Iteration {i}")
            self.expectation()
            change = self.maximization(last_iteration=i == self.num_iterations - 1)
            if change < self.em_threshold:
                break

    @property
    def data_directory(self) -> Path:
        """Data directory for trainer"""
        return self.working_directory

    def train_iteration(self) -> None:
        """Train iteration, not used"""
        pass

    @property
    def data_source_identifier(self) -> str:
        """Dictionary name"""
        return self._data_source

    def export_model(self, output_model_path: Path) -> None:
        """
        Export G2P model to specified path

        Parameters
        ----------
        output_model_path: :class:`~pathlib.Path`
            Path to export model
        """
        directory = output_model_path.parent
        directory.mkdir(parents=True, exist_ok=True)
        models_temp_dir = self.working_directory.joinpath("model_archive_temp")
        model = G2PModel.empty(output_model_path.stem, root_directory=models_temp_dir)
        model.add_meta_file(self)
        model.add_fst_model(self.working_directory)
        model.add_sym_path(self.working_directory)
        if directory:
            os.makedirs(directory, exist_ok=True)
        model.dump(output_model_path)
        model.clean_up()
        logger.info(f"Saved model to {output_model_path}")

    @property
    def alignment_model_path(self) -> Path:
        """Path to store alignment model FST"""
        return self.working_directory.joinpath("align.fst")

    @property
    def ngram_path(self) -> Path:
        """Path to store ngram model"""
        return self.working_directory.joinpath("ngram.fst")

    @property
    def fst_path(self) -> Path:
        """Path to store final trained model"""
        return self.working_directory.joinpath("model.fst")

    @property
    def alignment_symbols_path(self) -> Path:
        """Path to alignment symbol table"""
        return self.working_directory.joinpath("alignment.syms")

    @property
    def grapheme_symbols_path(self) -> Path:
        """Path to final model's grapheme symbol table"""
        return self.working_directory.joinpath("graphemes.txt")

    @property
    def phone_symbols_path(self) -> Path:
        """Path to final model's phone symbol table"""
        return self.working_directory.joinpath("phones.txt")

    @property
    def far_path(self) -> Path:
        """Path to store final aligned FSTs"""
        return self.working_directory.joinpath("aligned.far")

    def export_alignments(self) -> None:
        """
        Combine alignment training archives to a final combined FST archive to train the ngram model
        """
        logger.info("Exporting final alignments...")

        return_queue = mp.Queue()
        stopped = mp.Event()
        error_list = []
        procs = []
        count_paths = []
        for i in range(1, config.NUM_JOBS + 1):
            args = AlignmentExportArguments(
                self.working_log_directory.joinpath(f"ngram_count.{i}.log"),
                self.working_directory.joinpath(f"{i}.far"),
                self.penalize,
            )
            procs.append(AlignmentExporter(return_queue, stopped, args))
            count_paths.append(args.far_path.with_suffix(".cnts"))
            procs[-1].start()

        with tqdm(total=self.g2p_num_training_pronunciations, disable=config.QUIET) as pbar:
            while True:
                try:
                    result = return_queue.get(timeout=1)
                    if isinstance(result, Exception):
                        error_list.append(result)
                        continue
                    if stopped.is_set():
                        continue
                except queue.Empty:
                    for p in procs:
                        if not p.finished.is_set():
                            break
                    else:
                        break
                    continue
                pbar.update(1)
        for p in procs:
            p.join()
        if error_list:
            for v in error_list:
                raise v
        with mfa_open(self.working_log_directory.joinpath("symbols.log"), "w") as log_file:
            symbols_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramsymbols"),
                    "--OOV_symbol=<unk>",
                    "--epsilon_symbol=<eps>",
                    "-",
                    self.alignment_symbols_path,
                ],
                encoding="utf8",
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=log_file,
            )
            for j in range(1, config.NUM_JOBS + 1):
                text_path = self.working_directory.joinpath(f"{j}.strings")
                with mfa_open(text_path, "r") as f:
                    for line in f:
                        symbols_proc.stdin.write(line)
                        symbols_proc.stdin.flush()
            symbols_proc.stdin.close()
            symbols_proc.wait()
        self.symbol_table = pywrapfst.SymbolTable.read_text(self.alignment_symbols_path)
        logger.info("Done exporting alignments!")

    def compute_initial_ngrams(self) -> None:
        """Compute the initial sets of allowed multi-grapheme and multi-phone ngrams"""
        input_path = self.working_directory.joinpath("input.txt")
        input_ngram_path = self.working_directory.joinpath("input_ngram.fst")
        if input_ngram_path.with_suffix(".ngrams").exists():
            logger.info("Initial ngrams already computed")
            return
        logger.info("Computing initial ngrams...")
        with mfa_open(self.working_log_directory.joinpath("initial_ngrams.log"), "w") as log_file:
            input_symbols_path = self.working_directory.joinpath("input_ngram.syms")
            subprocess.check_call(
                [
                    thirdparty_binary("ngramsymbols"),
                    "--OOV_symbol=<unk>",
                    "--epsilon_symbol=<eps>",
                    input_path,
                    input_symbols_path,
                ],
                encoding="utf8",
                stderr=log_file,
            )
            farcompile_proc = subprocess.Popen(
                [
                    thirdparty_binary("farcompilestrings"),
                    "--fst_type=compact",
                    "--token_type=symbol",
                    f"--symbols={input_symbols_path}",
                    input_path,
                ],
                env=os.environ,
                stderr=log_file,
                stdout=subprocess.PIPE,
            )
            ngramcount_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramcount"),
                    "--require_symbols=false",
                    "--round_to_int",
                    f"--order={self.input_order}",
                ],
                stdin=farcompile_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
                stderr=log_file,
            )
            ngrammake_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngrammake"),
                    f"--method={self.smoothing_method}",
                ],
                stdin=ngramcount_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
                stderr=log_file,
            )
            ngramshrink_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramshrink"),
                    f"--method={self.pruning_method}",
                    f"--theta={self.initial_prune_threshold}",
                ],
                stdin=ngrammake_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
                stderr=log_file,
            )
            print_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramprint"),
                    f"--symbols={input_symbols_path}",
                ],
                env=os.environ,
                stdin=ngramshrink_proc.stdout,
                stdout=subprocess.PIPE,
                encoding="utf8",
                stderr=log_file,
            )
            ngrams = set()
            for line in print_proc.stdout:
                line = line.strip().split()[:-1]
                ngram = self.sequence_separator.join(x for x in line if x not in {"<s>", "</s>"})
                if self.sequence_separator not in ngram:
                    continue
                ngrams.add(ngram)
            print_proc.wait()
            with mfa_open(input_ngram_path.with_suffix(".ngrams"), "w") as f:
                for ngram in sorted(ngrams):
                    f.write(f"{ngram}\n")

            output_path = self.working_directory.joinpath("output.txt")
            output_ngram_path = self.working_directory.joinpath("output_ngram.fst")
            output_symbols_path = self.working_directory.joinpath("output_ngram.syms")
            symbols_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramsymbols"),
                    "--OOV_symbol=<unk>",
                    "--epsilon_symbol=<eps>",
                    output_path,
                    output_symbols_path,
                ],
                encoding="utf8",
                stderr=log_file,
            )
            symbols_proc.communicate()
            farcompile_proc = subprocess.Popen(
                [
                    thirdparty_binary("farcompilestrings"),
                    "--fst_type=compact",
                    "--token_type=symbol",
                    f"--symbols={output_symbols_path}",
                    output_path,
                ],
                stdout=subprocess.PIPE,
                env=os.environ,
                stderr=log_file,
            )
            ngramcount_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramcount"),
                    "--require_symbols=false",
                    "--round_to_int",
                    f"--order={self.output_order}",
                ],
                stdin=farcompile_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
                stderr=log_file,
            )
            ngrammake_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngrammake"),
                    f"--method={self.smoothing_method}",
                ],
                stdin=ngramcount_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
                stderr=log_file,
            )
            ngramshrink_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramshrink"),
                    f"--method={self.pruning_method}",
                    f"--theta={self.initial_prune_threshold}",
                ],
                stdin=ngrammake_proc.stdout,
                stdout=subprocess.PIPE,
                env=os.environ,
                stderr=log_file,
            )
            print_proc = subprocess.Popen(
                [
                    thirdparty_binary("ngramprint"),
                    f"--symbols={output_symbols_path}",
                ],
                env=os.environ,
                stdin=ngramshrink_proc.stdout,
                stdout=subprocess.PIPE,
                encoding="utf8",
                stderr=log_file,
            )
            ngrams = set()
            for line in print_proc.stdout:
                line = line.strip().split()[:-1]
                ngram = self.sequence_separator.join(x for x in line if x not in {"<s>", "</s>"})
                if self.sequence_separator not in ngram:
                    continue
                ngrams.add(ngram)
            print_proc.wait()
            with mfa_open(output_ngram_path.with_suffix(".ngrams"), "w") as f:
                for ngram in sorted(ngrams):
                    f.write(f"{ngram}\n")

    def train(self) -> None:
        """
        Train a G2P model
        """
        if os.path.exists(self.fst_path):
            return
        os.makedirs(self.working_log_directory, exist_ok=True)
        begin = time.time()
        self.train_alignments()
        logger.debug(
            f"Aligning {len(self.g2p_training_dictionary)} words took {time.time() - begin:.3f} seconds"
        )
        self.export_alignments()
        begin = time.time()
        self.train_ngram_model()
        logger.debug(
            f"Generating model for {len(self.g2p_training_dictionary)} words took {time.time() - begin:.3f} seconds"
        )
        self.finalize_training()


class PhonetisaurusTrainer(
    MultispeakerDictionaryMixin, PhonetisaurusTrainerMixin, G2PTrainer, TopLevelMfaWorker
):
    """
    Top-level trainer class for Phonetisaurus-style models
    """

    def __init__(
        self,
        **kwargs,
    ):
        self._data_source = kwargs["dictionary_path"].stem
        super().__init__(**kwargs)
        self.ler = None
        self.wer = None

    @property
    def data_directory(self) -> Path:
        """Data directory for trainer"""
        return self.working_directory

    @property
    def configuration(self) -> MetaDict:
        """Configuration for G2P trainer"""
        config = super().configuration
        config.update({"dictionary_path": str(self.dictionary_model.path)})
        return config

    def setup(self) -> None:
        """Setup for G2P training"""
        super().setup()
        self.create_new_current_workflow(WorkflowType.train_g2p)
        wf = self.current_workflow
        if wf.done:
            logger.info("G2P training already done, skipping.")
            return
        self.dictionary_setup()
        os.makedirs(self.phones_dir, exist_ok=True)
        self.initialize_training()
        self.initialized = True

    def finalize_training(self) -> None:
        """Finalize training and run evaluation if specified"""
        if self.evaluation_mode:
            self.evaluate_g2p_model()

    @property
    def meta(self) -> MetaDict:
        """Metadata for exported G2P model"""
        from datetime import datetime

        from ..utils import get_mfa_version

        meta = {
            "version": get_mfa_version(),
            "architecture": self.architecture,
            "train_date": str(datetime.now()),
            "phones": sorted(self.non_silence_phones),
            "graphemes": self.g2p_training_graphemes,
            "grapheme_order": self.grapheme_order,
            "phone_order": self.phone_order,
            "sequence_separator": self.sequence_separator,
            "evaluation": {},
            "training": {
                "num_words": self.g2p_num_training_words,
                "num_graphemes": len(self.g2p_training_graphemes),
                "num_phones": len(self.non_silence_phones),
            },
        }
        if self.model_version is not None:
            meta["version"] = self.model_version
        if self.evaluation_mode:
            meta["evaluation"]["num_words"] = self.g2p_num_validation_words
            meta["evaluation"]["word_error_rate"] = self.wer
            meta["evaluation"]["phone_error_rate"] = self.ler
        return meta

    def evaluate_g2p_model(self) -> None:
        """
        Validate the G2P model against held-out data
        """
        temp_model_path = self.working_log_directory.joinpath("g2p_model.zip")
        self.export_model(temp_model_path)
        temp_dir = self.working_directory.joinpath("validation")
        os.makedirs(temp_dir, exist_ok=True)
        with self.session() as session:
            validation_set = collections.defaultdict(set)
            query = (
                session.query(Word.word, Pronunciation.pronunciation)
                .join(Pronunciation.word)
                .join(Word.job)
                .filter(Word2Job.training == False)  # noqa
            )
            for w, pron in query:
                validation_set[w].add(pron)
        gen = PyniniValidator(
            g2p_model_path=temp_model_path,
            word_list=list(validation_set.keys()),
            num_pronunciations=self.num_pronunciations,
        )
        output = gen.generate_pronunciations()
        with mfa_open(temp_dir.joinpath("validation_output.txt"), "w") as f:
            for orthography, pronunciations in output.items():
                if not pronunciations:
                    continue
                for p in pronunciations:
                    if not p:
                        continue
                    f.write(f"{orthography}\t{p}\n")
        gen.compute_validation_errors(validation_set, output)

    def initialize_training(self) -> None:
        """Initialize training G2P model"""
        with self.session() as session:
            session.query(Word2Job).delete()
            session.query(M2M2Job).delete()
            session.query(M2MSymbol).delete()
            session.query(Job).delete()
            session.commit()
            job_objs = [{"id": j} for j in range(1, config.NUM_JOBS + 1)]
            self.g2p_num_training_pronunciations = 0
            self.g2p_num_validation_pronunciations = 0
            self.g2p_num_training_words = 0
            self.g2p_num_validation_words = 0
            # Partition the sorted list of words so that each process handles a
            # distinct slice; the per-job symbol tables then don't completely
            # overlap, which keeps memory use down
            num_words = session.query(Word.id).count()
            words_per_job = int(num_words / config.NUM_JOBS) + 1
            current_job = 1
            words = session.query(Word.id).filter(Word.word_type.in_(WordType.speech_types()))
            mappings = []
            for i, (w,) in enumerate(words):
                if i >= current_job * words_per_job and current_job != config.NUM_JOBS + 1:
                    current_job += 1
                mappings.append({"word_id": w, "job_id": current_job, "training": 1})
            session.bulk_insert_mappings(Job, job_objs)
            session.flush()
            session.execute(sqlalchemy.insert(Word2Job.__table__), mappings)
            session.commit()

            if self.evaluation_mode:
                validation_items = int(num_words * self.validation_proportion)
                validation_words = (
                    sqlalchemy.select(Word.id)
                    .order_by(sqlalchemy.func.random())
                    .limit(validation_items)
                    .scalar_subquery()
                )
                query = (
                    sqlalchemy.update(Word2Job)
                    .execution_options(synchronize_session="fetch")
                    .values(training=False)
                    .where(Word2Job.word_id.in_(validation_words))
                )
                with session.begin_nested():
                    session.execute(query)
                    session.flush()
                session.commit()
                query = (
                    session.query(Word.word, Pronunciation.pronunciation)
                    .join(Pronunciation.word)
                    .join(Word.job)
                    .filter(Word2Job.training == False)  # noqa
                )
                for word, pronunciation in query:
                    self.g2p_validation_graphemes.update(word)
                    self.g2p_validation_phones.update(pronunciation.split())
                    self.g2p_num_validation_pronunciations += 1
                query = (
                    session.query(Pronunciation.pronunciation, Word.word)
                    .join(Pronunciation.word)
                    .join(Word.job)
                    .filter(Word2Job.training == True)  # noqa
                )
                for pronunciation, word in query:
                    word = list(word)
                    self.g2p_training_graphemes.update(word)
                    self.g2p_training_phones.update(pronunciation.split())
                grapheme_diff = sorted(
                    self.g2p_validation_graphemes - self.g2p_training_graphemes
                )
                phone_diff = sorted(self.g2p_validation_phones - self.g2p_training_phones)
                if grapheme_diff:
                    low_count_grapheme_pattern = rf'[{"".join(grapheme_diff)}]'
                    query = session.query(Word.id).filter(
                        Word.word.op("~")(low_count_grapheme_pattern)
                        if config.USE_POSTGRES
                        else Word.word.regexp_match(low_count_grapheme_pattern)
                    )
                    logger.debug(
                        f'Adding {query.count()} low grapheme count words to training data for: {", ".join(grapheme_diff)}'
                    )
                    query = (
                        sqlalchemy.update(Word2Job)
                        .execution_options(synchronize_session="fetch")
                        .values(training=True)
                        .where(Word2Job.word_id.in_(query.subquery()))
                    )
                    with session.begin_nested():
                        session.execute(query)
                    session.commit()
                    self.g2p_training_graphemes.update(grapheme_diff)
                if phone_diff:
                    low_count_phone_pattern = rf'\b({"|".join(phone_diff)})\b'
                    query = (
                        session.query(Pronunciation.word_id)
                        .filter(
                            Pronunciation.pronunciation.op("~")(low_count_phone_pattern)
                            if config.USE_POSTGRES
                            else Pronunciation.pronunciation.regexp_match(
                                low_count_phone_pattern
                            )
                        )
                        .distinct()
                    )
                    logger.debug(
                        f'Adding {query.count()} low phone count words to training data for: {", ".join(phone_diff)}'
                    )
                    query = (
                        sqlalchemy.update(Word2Job)
                        .execution_options(synchronize_session="fetch")
                        .values(training=True)
                        .where(Word2Job.word_id.in_(query.subquery()))
                    )
                    with session.begin_nested():
                        session.execute(query)
                    session.commit()
                    self.g2p_training_phones.update(phone_diff)
                self.g2p_num_validation_words = (
                    session.query(Word2Job.word_id)
                    .filter(Word2Job.training == False)  # noqa
                    .count()
                )

            grapheme_count = 0
            phone_count = 0
            self.character_sets = set()
            query = (
                session.query(Pronunciation.pronunciation, Word.word)
                .join(Pronunciation.word)
                .join(Word.job)
                .filter(Word2Job.training == True)  # noqa
            )
            with mfa_open(self.working_directory.joinpath("input.txt"), "w") as word_f, mfa_open(
                self.working_directory.joinpath("output.txt"), "w"
            ) as phone_f:
                for pronunciation, word in query:
                    word = list(word)
                    grapheme_count += len(word)
                    self.g2p_training_graphemes.update(word)
                    self.g2p_num_training_pronunciations += 1
                    self.g2p_training_phones.update(pronunciation.split())
                    phone_count += len(pronunciation.split())
                    word_f.write(" ".join(word) + "\n")
                    phone_f.write(pronunciation + "\n")
            self.g2p_num_training_words = (
                session.query(Word2Job.word_id).filter(Word2Job.training == True).count()  # noqa
            )
            logger.debug(f"Graphemes in training data: {sorted(self.g2p_training_graphemes)}")
            logger.debug(f"Phones in training data: {sorted(self.g2p_training_phones)}")
            logger.debug(f"Average phones per grapheme: {phone_count / grapheme_count}")
            if self.sequence_separator in self.g2p_training_phones | self.g2p_training_graphemes:
                raise PhonetisaurusSymbolError(self.sequence_separator, "sequence_separator")
            if self.skip in self.g2p_training_phones | self.g2p_training_graphemes:
                raise PhonetisaurusSymbolError(self.skip, "skip")
            if self.alignment_separator in self.g2p_training_phones | self.g2p_training_graphemes:
                raise PhonetisaurusSymbolError(self.alignment_separator, "alignment_separator")
            if self.evaluation_mode:
                logger.debug(
                    f"Graphemes in validation data: {sorted(self.g2p_validation_graphemes)}"
                )
                logger.debug(f"Phones in validation data: {sorted(self.g2p_validation_phones)}")
            self.compute_initial_ngrams()
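

# Illustrative usage sketch (not part of the original module): a hypothetical
# end-to-end training run. The keyword arguments come from the mixins above;
# the dictionary and output paths here are placeholders.
if __name__ == "__main__":
    trainer = PhonetisaurusTrainer(
        dictionary_path=Path("dictionary.dict"),  # hypothetical input lexicon
        grapheme_order=2,
        phone_order=2,
        num_iterations=10,
    )
    trainer.setup()
    trainer.train()
    trainer.export_model(Path("example_g2p.zip"))  # hypothetical output path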