Source code for montreal_forced_aligner.corpus.remapper

"""Classes for remapping alignments from one phone set to another"""
from __future__ import annotations

import logging
import os
import threading
import time
import typing
from pathlib import Path
from queue import Empty, Queue

import yaml
from tqdm.rich import tqdm

from montreal_forced_aligner import config
from montreal_forced_aligner.abc import PhoneRemapperMixin, TopLevelMfaWorker
from montreal_forced_aligner.corpus.helper import find_exts
from montreal_forced_aligner.corpus.multiprocessing import AlignmentRemapperWorker
from montreal_forced_aligner.helper import mfa_open

logger = logging.getLogger("mfa")

__all__ = ["AlignmentRemapper"]


class AlignmentRemapper(PhoneRemapperMixin, TopLevelMfaWorker):
    """
    Top-level worker that remaps the phones of existing TextGrid alignments
    to another phone set.

    The phone mapping is read from the YAML file at ``self.phone_mapping_path``
    and stored in ``self.phone_remapping``; every TextGrid found under
    ``corpus_directory`` is then handed to a pool of
    :class:`AlignmentRemapperWorker` instances.

    Parameters
    ----------
    corpus_directory: str or :class:`~pathlib.Path`
        Directory that is walked recursively for TextGrid files to remap
    split_percentage: float
        Value forwarded to the workers; per the warning emitted in
        :meth:`validate_mapping` it is only relevant when one phone is split
        into two phones. Defaults to 0.5
    **kwargs
        Forwarded unchanged to the parent classes' ``__init__``
    """

    def __init__(
        self,
        corpus_directory: typing.Union[str, Path],
        split_percentage: float = 0.5,
        **kwargs,
    ):
        self.corpus_directory = Path(corpus_directory)
        self._data_source = self.corpus_directory.stem
        self.split_percentage = split_percentage
        super().__init__(**kwargs)
        # Created lazily in remap_alignments() as a threading.Event
        self.stopped = None

    @property
    def data_source_identifier(self) -> str:
        """Name of the data source (the stem of the corpus directory)"""
        return self._data_source

    @property
    def data_directory(self) -> Path:
        """Data directory for the remapper (same as the working directory)"""
        return self.working_directory

    def setup(self) -> None:
        """Load and validate the phone mapping, then mark the worker as initialized"""
        super().setup()
        # NOTE(review): the mapping is (re)loaded on every call because the
        # ``initialized`` guard only comes afterwards -- confirm this is intended.
        self.load_mapping()
        self.validate_mapping()
        if self.initialized:
            return
        self.initialized = True

    def load_mapping(self) -> None:
        """
        Read the YAML phone mapping into ``self.phone_remapping``.

        A list value is treated as ambiguous and reduced to its first element
        (with a warning); a value containing spaces is stored as a tuple of
        phones, i.e. one source phone is split into several target phones.
        """
        with mfa_open(self.phone_mapping_path, "r") as f:
            # NOTE(security): yaml.Loader can construct arbitrary Python objects.
            # Only load mapping files from trusted sources; consider
            # yaml.safe_load if mappings never need custom tags.
            data = yaml.load(f, Loader=yaml.Loader)
        for key, value in data.items():
            if isinstance(value, list):
                value = value[0]
                logger.warning(
                    f"Found ambiguous mapping for {key}, using first value ({value}) as the target."
                )
            if " " in value:
                value = tuple(value.split())
            self.phone_remapping[key] = value

    def validate_mapping(self):
        """
        Warn when the mapping splits one phone into several phones while the
        default ``split_percentage`` is in effect.
        """
        # load_mapping() stores multi-phone targets as tuples
        found_splitting = any(
            isinstance(value, tuple) for value in self.phone_remapping.values()
        )
        # The hint below describes the default (equal) split, so it is only
        # accurate -- and only useful -- when the user has not overridden it.
        if found_splitting and self.split_percentage == 0.5:
            logger.warning(
                "Found instances of splitting one phone to multiple phones, "
                "be aware that new segments will receive equal distribution of duration. "
                "If a different point is better, use --split_percentage 0.75 to specify 75%, for instance, "
                "but this will only affect behavior when splitting to two phones."
            )

    def remap_alignments(
        self,
        output_directory: typing.Union[Path, str],
        output_format: typing.Literal[
            "short_textgrid", "long_textgrid", "json", "textgrid_json"
        ] = "short_textgrid",
    ):
        """
        Remap every TextGrid under the corpus directory and write the results
        to ``output_directory``, mirroring the corpus' sub-directory layout.

        Parameters
        ----------
        output_directory: :class:`~pathlib.Path` or str
            Root directory for the remapped files; created if missing
        output_format: str
            Output format forwarded to the workers, defaults to "short_textgrid"

        Raises
        ------
        Exception
            Re-raises the object a worker reported under the ``"error"`` key
        """
        if self.stopped is None:
            self.stopped = threading.Event()
        output_directory = Path(output_directory)
        output_directory.mkdir(parents=True, exist_ok=True)
        begin_time = time.time()
        job_queue = Queue()
        return_queue = Queue()
        error_dict = {}
        finished_adding = threading.Event()
        procs = []
        for i in range(config.NUM_JOBS):
            p = AlignmentRemapperWorker(
                i,
                job_queue,
                return_queue,
                self.stopped,
                finished_adding,
                self.phone_remapping,
                self.split_percentage,
                output_format,
            )
            procs.append(p)
            p.start()
        try:
            file_count = 0
            with tqdm(total=1, disable=config.QUIET) as pbar:
                for root, dirs, files in os.walk(self.corpus_directory, followlinks=True):
                    if self.stopped.is_set():
                        break
                    # Ignore hidden directories: prune them in place so os.walk
                    # never descends into them.  Testing ``root`` itself would
                    # wrongly skip the whole corpus for relative paths such as
                    # "." or "../corpus".
                    dirs[:] = [d for d in dirs if not d.startswith(".")]
                    exts = find_exts(files)
                    # relpath strips only the leading corpus prefix (a plain
                    # str.replace would also mangle matching substrings deeper
                    # in the path); it yields "." for the corpus root itself.
                    relative_path = os.path.relpath(root, self.corpus_directory)
                    for tg_name in exts.textgrid_files.values():
                        if self.stopped.is_set():
                            break
                        input_path = os.path.join(root, tg_name)
                        output_dir = output_directory.joinpath(relative_path)
                        output_dir.mkdir(parents=True, exist_ok=True)
                        output_path = output_dir.joinpath(tg_name)
                        job_queue.put((input_path, output_path))
                        file_count += 1
                pbar.total = file_count
                finished_adding.set()
                while True:
                    try:
                        result = return_queue.get(timeout=1)
                    except Empty:
                        # Nothing pending: stop once every worker has finished
                        if all(proc.finished_processing.is_set() for proc in procs):
                            break
                        continue
                    if isinstance(result, tuple):
                        # Workers report problems as (error_type, error, ...) tuples
                        error_type = result[0]
                        error = result[1]
                        if error_type == "error":
                            error_dict[error_type] = error
                        else:
                            error_dict.setdefault(error_type, []).append(error)
                        continue
                    if self.stopped.is_set():
                        continue
                    pbar.update(1)
                    return_queue.task_done()
            logger.debug("Waiting for workers to finish...")
            for p in procs:
                p.join()
            if "error" in error_dict:
                raise error_dict["error"]
        except KeyboardInterrupt:
            logger.info("Detected ctrl-c, please wait a moment while we clean everything up...")
            self.stopped.set()
            finished_adding.set()
            # Drain whatever the workers still emit so they can shut down
            while True:
                try:
                    _ = return_queue.get(timeout=1)
                    return_queue.task_done()
                except Empty:
                    if all(proc.finished_processing.is_set() for proc in procs):
                        break
        finally:
            finished_adding.set()
            for p in procs:
                p.join()
            if self.stopped.is_set():
                logger.info(f"Stopped parsing early ({time.time() - begin_time:.3f} seconds)")
            else:
                logger.debug(
                    f"Remapped alignments with {config.NUM_JOBS} jobs in {time.time() - begin_time:.3f} seconds"
                )