Source code for montreal_forced_aligner.tokenization.trainer

"""Classes for training tokenizers"""
import collections
import logging
import os
import shutil
import subprocess
import time
from pathlib import Path

import pynini
import pywrapfst
import sqlalchemy

from montreal_forced_aligner import config
from montreal_forced_aligner.abc import MetaDict, TopLevelMfaWorker
from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpusMixin
from montreal_forced_aligner.data import WorkflowType
from montreal_forced_aligner.db import M2M2Job, M2MSymbol, Utterance
from montreal_forced_aligner.dictionary.mixins import DictionaryMixin
from montreal_forced_aligner.exceptions import KaldiProcessingError
from montreal_forced_aligner.g2p.phonetisaurus_trainer import (
    AlignmentInitWorker,
    PhonetisaurusTrainerMixin,
)
from montreal_forced_aligner.g2p.trainer import G2PTrainer, PyniniTrainerMixin
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.models import TokenizerModel
from montreal_forced_aligner.tokenization.tokenizer import TokenizerValidator
from montreal_forced_aligner.utils import log_kaldi_errors, thirdparty_binary

__all__ = ["TokenizerTrainer"]

logger = logging.getLogger("mfa")


class TokenizerAlignmentInitWorker(AlignmentInitWorker):
    """
    Multiprocessing worker that initializes alignment FSTs for a subset of the data

    Parameters
    ----------
    job_name: int
        Integer ID for the job
    return_queue: :class:`multiprocessing.Queue`
        Queue to return data
    stopped: :class:`~threading.Event`
        Stop check
    finished_adding: :class:`~threading.Event`
        Check for whether the job queue is done
    args: :class:`~montreal_forced_aligner.g2p.phonetisaurus_trainer.AlignmentInitArguments`
        Arguments for initialization
    """

    def data_generator(self, session):
        grapheme_table = pywrapfst.SymbolTable.read_text(self.far_path.with_name("graphemes.syms"))
        query = session.query(Utterance.normalized_character_text).filter(
            Utterance.ignored == False, Utterance.job_id == self.job_name  # noqa
        )
        for (text,) in query:
            tokenized = [x if grapheme_table.member(x) else "<unk>" for x in text.split()]
            untokenized = [x for x in tokenized if x != "<space>"]
            yield untokenized, tokenized

    def run(self) -> None:
        """Run the function"""

        engine = sqlalchemy.create_engine(
            self.db_string,
            poolclass=sqlalchemy.NullPool,
            pool_reset_on_return=None,
            isolation_level="AUTOCOMMIT",
            logging_name=f"{type(self).__name__}_engine",
        ).execution_options(logging_token=f"{type(self).__name__}_engine")
        try:
            symbol_table = pywrapfst.SymbolTable()
            symbol_table.add_symbol(self.eps)
            valid_output_ngrams = set()
            base_dir = os.path.dirname(self.far_path)
            with mfa_open(os.path.join(base_dir, "output_ngram.ngrams"), "r") as f:
                for line in f:
                    line = line.strip()
                    valid_output_ngrams.add(line)
            valid_input_ngrams = set()
            with mfa_open(os.path.join(base_dir, "input_ngram.ngrams"), "r") as f:
                for line in f:
                    line = line.strip()
                    valid_input_ngrams.add(line)
            count = 0
            data = {}
            with mfa_open(self.log_path, "w") as log_file, sqlalchemy.orm.Session(
                engine
            ) as session:
                far_writer = pywrapfst.FarWriter.create(self.far_path, arc_type="log")
                for current_index, (input, output) in enumerate(self.data_generator(session)):
                    if self.stopped.is_set():
                        continue
                    try:
                        key = f"{current_index:08x}"
                        fst = pynini.Fst(arc_type="log")
                        final_state = ((len(input) + 1) * (len(output) + 1)) - 1

                        for _ in range(final_state + 1):
                            fst.add_state()
                        for i in range(len(input) + 1):
                            for j in range(len(output) + 1):
                                istate = i * (len(output) + 1) + j

                                for input_range in range(1, self.input_order + 1):
                                    for output_range in range(input_range, self.output_order + 1):
                                        if i + input_range <= len(
                                            input
                                        ) and j + output_range <= len(output):
                                            if (
                                                self.restrict
                                                and input_range > 1
                                                and output_range > 1
                                            ):
                                                continue
                                            subseq_output = output[j : j + output_range]
                                            output_string = self.seq_sep.join(subseq_output)
                                            if (
                                                output_range > 1
                                                and output_string not in valid_output_ngrams
                                            ):
                                                continue
                                            subseq_input = input[i : i + input_range]
                                            input_string = self.seq_sep.join(subseq_input)
                                            if output_range > 1:
                                                if "<space>" not in subseq_output:
                                                    continue
                                                if input_string not in output_string:
                                                    continue
                                            if (
                                                output_range == input_range
                                                and input_string != output_string
                                            ):
                                                continue
                                            if (
                                                input_range > 1
                                                and input_string not in valid_input_ngrams
                                            ):
                                                continue
                                            symbol = self.s1s2_sep.join(
                                                [input_string, output_string]
                                            )
                                            ilabel = symbol_table.find(symbol)
                                            if ilabel == pywrapfst.NO_LABEL:
                                                ilabel = symbol_table.add_symbol(symbol)
                                            ostate = (i + input_range) * (len(output) + 1) + (
                                                j + output_range
                                            )
                                            fst.add_arc(
                                                istate,
                                                pywrapfst.Arc(
                                                    ilabel,
                                                    ilabel,
                                                    pywrapfst.Weight(
                                                        "log", float(input_range * output_range)
                                                    ),
                                                    ostate,
                                                ),
                                            )
                        fst.set_start(0)
                        fst.set_final(final_state, pywrapfst.Weight.one(fst.weight_type()))
                        fst = pynini.connect(fst)
                        for state in fst.states():
                            for arc in fst.arcs(state):
                                sym = symbol_table.find(arc.ilabel)
                                if sym not in data:
                                    data[sym] = arc.weight
                                else:
                                    data[sym] = pywrapfst.plus(data[sym], arc.weight)
                        if count >= self.batch_size:
                            data = {k: float(v) for k, v in data.items()}
                            self.return_queue.put((self.job_name, data, count))
                            data = {}
                            count = 0
                        log_file.flush()
                        far_writer[key] = fst
                        del fst
                        count += 1
                    except Exception as e:  # noqa
                        self.stopped.set()
                        self.return_queue.put(e)
            if data:
                data = {k: float(v) for k, v in data.items()}
                self.return_queue.put((self.job_name, data, count))
            symbol_table.write_text(self.far_path.with_suffix(".syms"))
            return
        except Exception as e:
            self.stopped.set()
            self.return_queue.put(e)
        finally:
            self.finished.set()
            del far_writer


class TokenizerMixin(AcousticCorpusMixin, G2PTrainer, DictionaryMixin, TopLevelMfaWorker):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.training_graphemes = set()
        self.uer = None
        self.cer = None
        self.deletions = False
        self.insertions = True
        self.num_training_utterances = 0
        self.num_validation_utterances = 0

    def setup(self) -> None:
        super().setup()
        self.ignore_empty_utterances = True
        if self.initialized:
            return
        try:
            self._load_corpus()
            self._create_dummy_dictionary()
            self.initialize_jobs()
            self.normalize_text()
            self.initialize_training()
        except Exception as e:
            if isinstance(e, KaldiProcessingError):
                log_kaldi_errors(e.error_logs)
                e.update_log_file()
            raise
        self.initialized = True

    def evaluate_tokenizer(self) -> None:
        """
        Validate the tokenizer model against held out data
        """
        temp_model_path = self.working_log_directory.joinpath("tokenizer_model.zip")
        self.export_model(temp_model_path)
        temp_dir = self.working_directory.joinpath("validation")
        temp_dir.mkdir(parents=True, exist_ok=True)
        with self.session() as session:
            validation_set = {}
            query = session.query(Utterance.normalized_character_text).filter(
                Utterance.ignored == True  # noqa
            )
            for (text,) in query:
                tokenized = text.split()
                untokenized = [x for x in tokenized if x != "<space>"]
                tokenized = [x if x != "<space>" else " " for x in tokenized]
                validation_set[" ".join(untokenized)] = "".join(tokenized)
        gen = TokenizerValidator(
            tokenizer_model_path=temp_model_path,
            corpus_directory=self.corpus_directory,
            utterances_to_tokenize=list(validation_set.keys()),
        )
        output = gen.tokenize_utterances()
        with mfa_open(temp_dir.joinpath("validation_output.txt"), "w") as f:
            for orthography, pronunciations in output.items():
                if not pronunciations:
                    continue
                for p in pronunciations:
                    if not p:
                        continue
                    f.write(f"{orthography}\t{p}\n")
        gen.compute_validation_errors(validation_set, output)
        self.uer = gen.uer
        self.cer = gen.cer
        gen.cleanup()


class PhonetisaurusTokenizerTrainer(PhonetisaurusTrainerMixin, TokenizerMixin):
    alignment_init_function = TokenizerAlignmentInitWorker

    def __init__(
        self, input_order: int = 2, output_order: int = 3, oov_count_threshold: int = 5, **kwargs
    ):
        super().__init__(
            oov_count_threshold=oov_count_threshold,
            grapheme_order=input_order,
            phone_order=output_order,
            **kwargs,
        )

    @property
    def data_source_identifier(self) -> str:
        """Corpus name"""
        return self.corpus_directory.name

    @property
    def meta(self) -> MetaDict:
        """Metadata for exported tokenizer model"""
        from datetime import datetime

        from ..utils import get_mfa_version

        meta = {
            "version": get_mfa_version(),
            "architecture": self.architecture,
            "train_date": str(datetime.now()),
            "evaluation": {},
            "input_order": self.input_order,
            "output_order": self.output_order,
            "oov_count_threshold": self.oov_count_threshold,
            "training": {
                "num_utterances": self.num_training_utterances,
                "num_graphemes": len(self.training_graphemes),
            },
        }
        if self.model_version is not None:
            meta["version"] = self.model_version

        if self.evaluation_mode:
            meta["evaluation"]["num_utterances"] = self.num_validation_utterances
            meta["evaluation"]["utterance_error_rate"] = self.uer
            meta["evaluation"]["character_error_rate"] = self.cer
        return meta

    def train(self) -> None:
        if os.path.exists(self.fst_path):
            self.finalize_training()
            return
        super().train()

    def initialize_training(self) -> None:
        """Initialize training tokenizer model"""

        self.create_new_current_workflow(WorkflowType.tokenizer_training)
        if self.fst_path.exists():
            return
        logger.info("Initializing training...")
        with self.session() as session:
            session.query(M2M2Job).delete()
            session.query(M2MSymbol).delete()
            session.commit()
            self.num_validation_utterances = 0
            self.num_training_utterances = 0
            if self.evaluation_mode:
                validation_items = int(self.num_utterances * self.validation_proportion)
                validation_utterances = (
                    sqlalchemy.select(Utterance.id)
                    .order_by(sqlalchemy.func.random())
                    .limit(validation_items)
                    .scalar_subquery()
                )
                query = (
                    sqlalchemy.update(Utterance)
                    .execution_options(synchronize_session="fetch")
                    .values(ignored=True)
                    .where(Utterance.id.in_(validation_utterances))
                )
                with session.begin_nested():
                    session.execute(query)
                    session.flush()
                session.commit()
                self.num_validation_utterances = (
                    session.query(Utterance.id).filter(Utterance.ignored == True).count()  # noqa
                )

            query = session.query(Utterance.normalized_character_text).filter(
                Utterance.ignored == False  # noqa
            )
            unk_character = "<unk>"
            self.training_graphemes.add(unk_character)
            counts = collections.Counter()
            for (text,) in query:
                counts.update(text.split())
            with mfa_open(
                self.working_directory.joinpath("input.txt"), "w"
            ) as untokenized_f, mfa_open(
                self.working_directory.joinpath("output.txt"), "w"
            ) as tokenized_f:
                for (text,) in query:
                    assert text
                    tokenized = [
                        x if counts[x] >= self.oov_count_threshold else unk_character
                        for x in text.split()
                    ]
                    untokenized = [x for x in tokenized if x != "<space>"]
                    self.num_training_utterances += 1
                    self.training_graphemes.update(tokenized)
                    untokenized_f.write(" ".join(untokenized) + "\n")
                    tokenized_f.write(" ".join(tokenized) + "\n")
            index = 1
            with mfa_open(self.working_directory.joinpath("graphemes.syms"), "w") as f:
                f.write("<eps>\t0\n")
                for g in sorted(self.training_graphemes):
                    f.write(f"{g}\t{index}\n")
                    index += 1
            self.compute_initial_ngrams()
            self.g2p_num_training_pronunciations = self.num_training_utterances

    def finalize_training(self) -> None:
        """Finalize training"""
        shutil.copyfile(self.fst_path, self.working_directory.joinpath("tokenizer.fst"))
        shutil.copyfile(self.grapheme_symbols_path, self.working_directory.joinpath("input.syms"))
        shutil.copyfile(self.phone_symbols_path, self.working_directory.joinpath("output.syms"))
        if self.evaluation_mode:
            self.evaluate_tokenizer()

    def export_model(self, output_model_path: Path) -> None:
        """
        Export tokenizer model to specified path

        Parameters
        ----------
        output_model_path: :class:`~pathlib.Path`
            Path to export model
        """
        directory = output_model_path.parent

        models_temp_dir = self.working_directory.joinpath("model_archive_temp")
        model = TokenizerModel.empty(output_model_path.stem, root_directory=models_temp_dir)
        model.add_meta_file(self)
        model.add_tokenizer_model(self.working_directory)
        model.add_graphemes_path(self.working_directory)
        if directory:
            os.makedirs(directory, exist_ok=True)
        model.dump(output_model_path)
        if not config.DEBUG:
            model.clean_up()
        # self.clean_up()
        logger.info(f"Saved model to {output_model_path}")


[docs] class TokenizerTrainer(PyniniTrainerMixin, TokenizerMixin): def __init__(self, oov_count_threshold=5, **kwargs): super().__init__(oov_count_threshold=oov_count_threshold, **kwargs) self.training_graphemes = set() self.uer = None self.cer = None self.deletions = False self.insertions = True @property def meta(self) -> MetaDict: """Metadata for exported tokenizer model""" from datetime import datetime from ..utils import get_mfa_version m = { "version": get_mfa_version(), "architecture": self.architecture, "train_date": str(datetime.now()), "evaluation": {}, "training": { "num_utterances": self.num_training_utterances, "num_graphemes": len(self.training_graphemes), }, } if self.evaluation_mode: m["evaluation"]["num_utterances"] = self.num_validation_utterances m["evaluation"]["utterance_error_rate"] = self.uer m["evaluation"]["character_error_rate"] = self.cer return m @property def data_source_identifier(self) -> str: """Corpus name""" return self.corpus_directory.name @property def sym_path(self) -> Path: return self.working_directory.joinpath("graphemes.syms")
[docs] def initialize_training(self) -> None: """Initialize training tokenizer model""" self.create_new_current_workflow(WorkflowType.tokenizer_training) with self.session() as session: self.num_validation_utterances = 0 self.num_training_utterances = 0 self.num_iterations = 1 self.random_starts = 1 self.input_token_type = self.sym_path self.output_token_type = self.sym_path if self.evaluation_mode: validation_items = int(self.num_utterances * self.validation_proportion) validation_utterances = ( sqlalchemy.select(Utterance.id) .order_by(sqlalchemy.func.random()) .limit(validation_items) .scalar_subquery() ) query = ( sqlalchemy.update(Utterance) .execution_options(synchronize_session="fetch") .values(ignored=True) .where(Utterance.id.in_(validation_utterances)) ) with session.begin_nested(): session.execute(query) session.flush() session.commit() self.num_validation_utterances = ( session.query(Utterance.id).filter(Utterance.ignored == True).count() # noqa ) query = session.query(Utterance.normalized_character_text).filter( Utterance.ignored == False # noqa ) unk_character = "<unk>" self.training_graphemes.add(unk_character) counts = collections.Counter() for (text,) in query: counts.update(text.split()) with mfa_open(self.input_path, "w") as untokenized_f, mfa_open( self.output_path, "w" ) as tokenized_f: for (text,) in query: assert text tokenized = [ x if counts[x] >= self.oov_count_threshold else unk_character for x in text.split() ] untokenized = [x for x in tokenized if x != "<space>"] self.num_training_utterances += 1 self.training_graphemes.update(tokenized) untokenized_f.write(" ".join(untokenized) + "\n") tokenized_f.write(" ".join(tokenized) + "\n") index = 1 with mfa_open(self.sym_path, "w") as f: f.write("<eps>\t0\n") for g in sorted(self.training_graphemes): f.write(f"{g}\t{index}\n") index += 1
def _lexicon_covering(self, input_path=None, output_path=None) -> None: """Builds covering grammar and lexicon FARs.""" # Sets of labels for the covering grammar. with mfa_open( self.working_log_directory.joinpath("covering_grammar.log"), "w" ) as log_file: if input_path is None: input_path = self.input_path if output_path is None: output_path = self.output_path com = [ thirdparty_binary("farcompilestrings"), "--fst_type=compact", ] com.append("--token_type=symbol") com.append( f"--symbols={self.sym_path}", ) com.append("--unknown_symbol=<unk>") com.extend([input_path, self.input_far_path]) print(" ".join(map(str, com)), file=log_file) subprocess.check_call(com, env=os.environ, stderr=log_file, stdout=log_file) com = [ thirdparty_binary("farcompilestrings"), "--fst_type=compact", "--token_type=symbol", f"--symbols={self.sym_path}", output_path, self.output_far_path, ] print(" ".join(map(str, com)), file=log_file) subprocess.check_call(com, env=os.environ, stderr=log_file, stdout=log_file) cg = pywrapfst.VectorFst() state = cg.add_state() cg.set_start(state) labels = pywrapfst.SymbolTable.read_text(self.sym_path) one = pywrapfst.Weight.one(cg.weight_type()) for i in range(labels.num_symbols()): if labels.find(i) == "<eps>": continue cg.add_arc(state, pywrapfst.Arc(i, i, one, state)) olabel = labels.find("<space>") cg.add_arc(state, pywrapfst.Arc(0, olabel, one, state)) cg.set_final(state) assert cg.verify(), "Label acceptor is ill-formed" cg.write(self.cg_path)
[docs] def train(self) -> None: """ Train a tokenizer model """ os.makedirs(self.working_log_directory, exist_ok=True) begin = time.time() if os.path.exists(self.far_path) and os.path.exists(self.encoder_path): logger.info("Alignment already done, skipping!") else: self.align_g2p() logger.debug(f"Aligning took {time.time() - begin:.3f} seconds") begin = time.time() self.generate_model() logger.debug(f"Generating model took {time.time() - begin:.3f} seconds") self.finalize_training()
[docs] def finalize_training(self) -> None: """Finalize training""" shutil.copyfile(self.fst_path, self.working_directory.joinpath("tokenizer.fst")) if self.evaluation_mode: self.evaluate_tokenizer()
[docs] def export_model(self, output_model_path: Path) -> None: """ Export tokenizer model to specified path Parameters ---------- output_model_path: :class:`~pathlib.Path` Path to export model """ directory = output_model_path.parent models_temp_dir = self.working_directory.joinpath("model_archive_temp") model = TokenizerModel.empty(output_model_path.stem, root_directory=models_temp_dir) model.add_meta_file(self) model.add_tokenizer_model(self.working_directory) model.add_graphemes_path(self.working_directory) if directory: os.makedirs(directory, exist_ok=True) model.dump(output_model_path) if not config.DEBUG: model.clean_up() # self.clean_up() logger.info(f"Saved model to {output_model_path}")