# Source code for montreal_forced_aligner.textgrid

"""
Textgrid utilities
==================

"""
from __future__ import annotations

import csv
import json
import os
import re
import sys
import traceback
import typing
from pathlib import Path
from typing import Dict, List

import dataclassy
import sqlalchemy
from kalpy.gmm.data import to_tg_interval
from praatio import textgrid as tgio
from praatio.data_classes.interval_tier import Interval
from praatio.utilities import utils as tgio_utils
from sqlalchemy.orm import Session

from montreal_forced_aligner.data import PhoneType, TextFileType, TextgridFormats, WordType
from montreal_forced_aligner.db import Phone, PhoneInterval, Speaker, Utterance, Word, WordInterval
from montreal_forced_aligner.exceptions import AlignmentExportError, CtmError, TextGridParseError
from montreal_forced_aligner.helper import mfa_open

# Names exported when this module is star-imported.
# NOTE(review): CtmInterval and parse_aligned_textgrid are defined in this
# module but are not listed here -- confirm whether that is intentional.
__all__ = [
    "load_textgrid",
    "Textgrid",
    "process_ctm_line",
    "export_textgrid",
    "construct_textgrid_output",
    "construct_output_path",
    "output_textgrid_writing_errors",
]


# noinspection PyUnresolvedReferences
@dataclassy.dataclass(slots=True)
class CtmInterval:
    """
    Data class for intervals derived from CTM files

    Parameters
    ----------
    begin: float
        Start time of interval
    end: float
        End time of interval
    label: int or str
        Label of interval
    confidence: float, optional
        Confidence score of the interval

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.CtmError`
        If begin or end are not valid
    """

    begin: float
    end: float
    label: typing.Union[int, str]
    confidence: typing.Optional[float] = None

    def __lt__(self, other: CtmInterval) -> bool:
        """Order intervals by their start time"""
        return self.begin < other.begin

    def __add__(self, other):
        """
        Concatenate onto the label or shift the interval in time

        Parameters
        ----------
        other: str or float
            A string is appended to the label and the combined label is returned;
            anything else is treated as a time offset added to both boundaries
            of this interval in place

        Returns
        -------
        str or :class:`CtmInterval`
            The concatenated label for string operands, otherwise this (shifted) interval
        """
        if isinstance(other, str):
            return self.label + other
        self.begin += other
        self.end += other
        # No __iadd__ is defined, so ``interval += offset`` falls back to __add__ and
        # rebinds the name to the return value; returning self keeps it from becoming None
        return self

    def __post_init__(self) -> None:
        """
        Check on data validity

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.CtmError`
            If begin or end are not valid
        """
        if self.end < -1 or self.begin == 1000000:
            raise CtmError(self)

    def to_tg_interval(self, file_duration=None) -> Interval:
        """
        Converts the CTMInterval to
        `PraatIO's Interval class <http://timmahrt.github.io/praatIO/praatio/utilities/constants.html#Interval>`_

        Parameters
        ----------
        file_duration: float, optional
            If given, the end of the interval is capped at this value

        Returns
        -------
        :class:`praatio.utilities.constants.Interval`
            Derived PraatIO Interval, with boundaries rounded to 6 decimal places

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.CtmError`
            If begin or end are not valid, or the interval is empty after rounding and capping
        """
        if self.end < -1 or self.begin == 1000000:
            raise CtmError(self)
        begin = round(self.begin, 6)
        end = round(self.end, 6)
        if file_duration is not None and end > file_duration:
            end = round(file_duration, 6)
        if begin >= end:
            raise CtmError(self)
        return Interval(begin, end, self.label)


class Textgrid(tgio.Textgrid):
    """
    PraatIO Textgrid subclass that overrides :meth:`save` with a writer for the
    long and short Praat TextGrid text formats
    """

    def save(
        self,
        fn: str,
        format: typing.Literal["short_textgrid", "long_textgrid", "json", "textgrid_json"],
        includeBlankSpaces: bool,
        minTimestamp: typing.Optional[float] = None,
        maxTimestamp: typing.Optional[float] = None,
        minimumIntervalLength: float = None,
        reportingMode: typing.Literal["silence", "warning", "error"] = "warning",
    ) -> None:
        """
        Save the current textgrid to a file

        Parameters
        ----------
        fn: str
            Full path of the output file
        format: str
            Content is only written for ``TextgridFormats.LONG_TEXTGRID`` and
            ``TextgridFormats.SHORT_TEXTGRID``; for any other value the file is
            opened for writing but nothing is written to it
        includeBlankSpaces: bool
            If True, every non-empty tier has its gaps filled with empty-label
            intervals (before the first entry, between entries and after the last
            entry up to the textgrid's ``maxTimestamp``). Gaps of 0.001 or less are
            left alone. The filler intervals are added to the tier itself, not
            just to the written output
        minTimestamp: float, optional
            Accepted for signature compatibility; not used by this implementation
        maxTimestamp: float, optional
            Accepted for signature compatibility; not used by this implementation
        minimumIntervalLength: float, optional
            Accepted for signature compatibility; not used by this implementation
        reportingMode: str
            Accepted for signature compatibility; not used by this implementation
        """
        indent = " " * 4
        gap_tolerance = 0.001

        with mfa_open(fn, mode="w") as out:
            if format not in {TextgridFormats.LONG_TEXTGRID, TextgridFormats.SHORT_TEXTGRID}:
                return
            long_form = format == TextgridFormats.LONG_TEXTGRID
            short_form = format == TextgridFormats.SHORT_TEXTGRID
            tier_count = len(self._tierDict)

            # File header
            if long_form:
                out.write(
                    'File type = "ooTextFile"\n'
                    'Object class = "TextGrid"\n\n'
                    f"xmin = {self.minTimestamp} \n"
                    f"xmax = {self.maxTimestamp} \n"
                    "tiers? <exists> \n"
                    f"size = {tier_count} \n"
                    "item []: \n"
                )
            elif short_form:
                out.write(
                    'File type = "ooTextFile"\n'
                    'Object class = "TextGrid"\n\n'
                    f"{self.minTimestamp}\n{self.maxTimestamp}\n"
                    f"<exists>\n{tier_count}\n"
                )

            for tier_number, (name, tier) in enumerate(self._tierDict.items(), start=1):
                entries = tier._entries
                if includeBlankSpaces and entries:
                    # Build the gap-filled sequence in one pass, then store it back
                    # into the tier's own list so the tier keeps the fillers
                    filled = []
                    if entries[0][0] > gap_tolerance:
                        filled.append(Interval(0.0, entries[0][0], ""))
                    for entry in entries:
                        start, end, label = entry
                        if filled and start - filled[-1][1] > gap_tolerance:
                            filled.append(Interval(filled[-1][1], start, ""))
                        filled.append(entry)
                    if self.maxTimestamp - filled[-1][1] > gap_tolerance:
                        filled.append(Interval(filled[-1][1], self.maxTimestamp, ""))
                    entries[:] = filled

                # Tier header
                tier_name = tgio_utils.escapeQuotes(name)
                entry_count = len(entries)
                if long_form:
                    out.write(
                        f"{indent}item [{tier_number}]:\n"
                        f'{indent * 2}class = "{tier.tierType}" \n'
                        f'{indent * 2}name = "{tier_name}" \n'
                        f"{indent * 2}xmin = {self.minTimestamp} \n"
                        f"{indent * 2}xmax = {self.maxTimestamp} \n"
                        f"{indent * 2}intervals: size = {entry_count} \n"
                    )
                elif short_form:
                    out.write(
                        f'"{tier.tierType}"\n'
                        f'"{tier_name}"\n'
                        f"{self.minTimestamp}\n{self.maxTimestamp}\n{entry_count}\n"
                    )

                # Tier entries
                for entry_number, (start, end, label) in enumerate(entries, start=1):
                    label = tgio_utils.escapeQuotes(label)
                    if long_form:
                        out.write(
                            f"{indent * 2}intervals [{entry_number}]:\n"
                            f"{indent * 3}xmin = {start} \n"
                            f"{indent * 3}xmax = {end} \n"
                            f'{indent * 3}text = "{label}" \n'
                        )
                    elif short_form:
                        out.write(f'{start}\n{end}\n"{label}"\n')


def process_ctm_line(
    line: str, reversed_phone_mapping: Dict[int, str], raw_id=False
) -> typing.Tuple[typing.Union[int, str], CtmInterval]:
    """
    Helper function for parsing a line of CTM file to construct a CTMInterval

    CTM format is:

    utt_id channel_num start_time phone_dur phone_id [confidence]

    Parameters
    ----------
    line: str
        Input string
    reversed_phone_mapping: dict[int, str]
        Mapping from integer IDs to phone labels
    raw_id: bool
        If True, the utterance field is returned unchanged as a string; otherwise
        the text after its last ``-`` is converted to an integer

    Returns
    -------
    int or str
        Utterance identifier
    :class:`CtmInterval`
        Extracted data from the line, with times and confidence rounded to 4 decimal places
    """
    fields = line.split()
    utt = fields[0]
    if not raw_id:
        utt = int(utt.split("-")[-1])
    begin = round(float(fields[2]), 4)
    # The end is derived from the already-rounded begin plus the raw duration
    end = round(begin + float(fields[3]), 4)
    conf = round(float(fields[5]), 4) if len(fields) > 5 else None
    label = reversed_phone_mapping[int(fields[4])]
    return utt, CtmInterval(begin, end, label, confidence=conf)
def output_textgrid_writing_errors(
    output_directory: str, export_errors: Dict[str, AlignmentExportError]
) -> None:
    """
    Output any errors that were encountered in writing TextGrids

    Any existing ``output_errors.txt`` in the output directory is removed first;
    a new one is only created when there is at least one error to report.

    Parameters
    ----------
    output_directory: str
        Directory to save TextGrids files
    export_errors: dict[str, :class:`~montreal_forced_aligner.exceptions.AlignmentExportError`]
        Dictionary of errors encountered
    """
    error_log = os.path.join(output_directory, "output_errors.txt")
    if os.path.exists(error_log):
        os.remove(error_log)
    if not export_errors:
        return
    # Open the log once instead of checking for it and reopening it for every error
    with mfa_open(error_log, "w") as f:
        f.write(
            "The following exceptions were encountered during the output of the alignments to TextGrids:\n\n"
        )
        for result in export_errors.values():
            f.write(f"{str(result)}\n\n")
def load_textgrid(path: typing.Union[Path, str]):
    """
    Load a TextGrid

    Parameters
    ----------
    path: :class:`~pathlib.Path` or str
        TextGrid file to parse

    Returns
    -------
    :class:`~praatio.data_classes.textgrid.Textgrid`
        Praatio TextGrid object

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.TextGridParseError`
        If praatio fails to open the file, or if no tiers were parsed
    """
    try:
        tg = tgio.openTextgrid(str(path), includeEmptyIntervals=False)
    except Exception:
        # Report any parsing failure with the formatted traceback attached
        exc_type, exc_value, exc_traceback = sys.exc_info()
        raise TextGridParseError(
            path,
            "\n".join(traceback.format_exception(exc_type, exc_value, exc_traceback)),
        )
    num_tiers = len(tg.tiers)
    if num_tiers == 0:
        raise TextGridParseError(path, "Number of tiers parsed was zero")
    return tg


def parse_aligned_textgrid(
    path: Path, root_speaker: typing.Optional[str] = None
) -> Dict[str, List[CtmInterval]]:
    """
    Load a TextGrid as a dictionary of speaker's phone tiers

    Only interval tiers whose name contains "phones" are read. Intervals with
    blank text, or shorter than 0.01 after rounding to 4 decimal places, are skipped.

    Parameters
    ----------
    path: :class:`~pathlib.Path`
        TextGrid file to parse
    root_speaker: str, optional
        Speaker name if the TextGrid has no speaker information

    Returns
    -------
    dict[str, list[:class:`CtmInterval`]]
        Parsed phone tier
    """
    tg = load_textgrid(path)
    data = {}
    phone_tier_pattern = re.compile(r"(.*) ?- ?phones")
    for tier_name in tg.tierNames:
        ti = tg._tierDict[tier_name]
        if not isinstance(ti, tgio.IntervalTier):
            continue
        if "phones" not in tier_name:
            continue
        # Tier names of the form "<speaker> - phones" carry the speaker name
        m = phone_tier_pattern.match(tier_name)
        if m:
            speaker_name = m.groups()[0].strip()
        elif root_speaker:
            speaker_name = root_speaker
        else:
            speaker_name = ""
        if speaker_name not in data:
            data[speaker_name] = []
        for begin, end, text in ti.entries:
            text = text.strip()
            if not text:
                continue
            begin, end = round(begin, 4), round(end, 4)
            if end - begin < 0.01:
                continue
            interval = CtmInterval(begin, end, text)
            data[speaker_name].append(interval)
    return data


def construct_textgrid_output(
    session: Session,
    file_batch: typing.Dict[int, typing.Tuple],
    cleanup_textgrids: bool,
    clitic_marker: str,
    output_directory: Path,
    frame_shift: float,
    output_format: str = TextgridFormats.SHORT_TEXTGRID,
    include_original_text: bool = False,
):
    """
    Export the alignments of a batch of files, one output file per file ID

    Phone and word intervals (and optionally utterance texts) are streamed from
    the database ordered by file ID, grouped per file and per speaker, and handed
    to :func:`export_textgrid`.

    Parameters
    ----------
    session: :class:`sqlalchemy.orm.Session`
        Database session used to run the interval queries
    file_batch: dict[int, tuple]
        Mapping of file ID to ``(file_name, relative_path, file_duration, text_file_path)``
    cleanup_textgrids: bool
        If True, silence phones and words are left out, and a word is merged into
        the previous word of the same speaker when it starts less than 0.02 after
        it and either side carries the clitic marker
    clitic_marker: str
        Marker checked at the end of the previous word and the start of the current word
    output_directory: :class:`~pathlib.Path`
        Root directory for the exported files
    frame_shift: float
        Frame shift passed through to :func:`export_textgrid`
    output_format: str
        Output format passed through to :func:`construct_output_path` and :func:`export_textgrid`
    include_original_text: bool
        If True, an "utterances" tier with each utterance's text is added per speaker

    Yields
    ------
    :class:`~pathlib.Path`
        Output path of each exported file
    """
    file_ids = list(file_batch.keys())
    phone_interval_query = (
        sqlalchemy.select(
            PhoneInterval.begin, PhoneInterval.end, Phone.phone, Speaker.name, Utterance.file_id
        )
        .execution_options(yield_per=1000)
        .join(PhoneInterval.phone)
        .join(PhoneInterval.utterance)
        .join(Utterance.speaker)
        .filter(PhoneInterval.duration > 0)
        .filter(Utterance.file_id.in_(file_ids))
    )
    word_interval_query = (
        sqlalchemy.select(
            WordInterval.begin, WordInterval.end, Word.word, Speaker.name, Utterance.file_id
        )
        .execution_options(yield_per=1000)
        .join(WordInterval.word)
        .join(WordInterval.utterance)
        .join(Utterance.speaker)
        .filter(WordInterval.duration > 0)
        .filter(Utterance.file_id.in_(file_ids))
    )
    if cleanup_textgrids:
        phone_interval_query = phone_interval_query.filter(Phone.phone_type != PhoneType.silence)
        word_interval_query = word_interval_query.filter(Word.word_type != WordType.silence)
    phone_intervals = session.execute(
        phone_interval_query.order_by(Utterance.file_id, PhoneInterval.begin)
    )
    word_intervals = session.execute(
        word_interval_query.order_by(Utterance.file_id, WordInterval.begin)
    )
    utterances = None
    if include_original_text:
        utterances = session.execute(
            sqlalchemy.select(
                Utterance.begin, Utterance.end, Utterance.text, Speaker.name, Utterance.file_id
            )
            .execution_options(yield_per=1000)
            .join(Utterance.speaker)
            .filter(Utterance.file_id.in_(file_ids))
            .order_by(Utterance.file_id)
        )
    pi_current_file_id = None
    wi_current_file_id = None
    u_current_file_id = None
    # Rows buffered for the file currently being assembled, one buffer per stream
    word_data = []
    phone_data = []
    utterance_data = []

    def process_phone_data():
        # Phones are processed first for each file, so this creates the per-speaker tiers
        for beg, end, p, speaker_name in phone_data:
            if speaker_name not in data:
                data[speaker_name] = {"words": [], "phones": []}
                if include_original_text:
                    data[speaker_name]["utterances"] = []
            data[speaker_name]["phones"].append(CtmInterval(beg, end, p))

    def process_word_data():
        for beg, end, w, speaker_name in word_data:
            if (
                cleanup_textgrids
                and data[speaker_name]["words"]
                and beg - data[speaker_name]["words"][-1].end < 0.02
                and clitic_marker
                and (
                    data[speaker_name]["words"][-1].label.endswith(clitic_marker)
                    or w.startswith(clitic_marker)
                )
            ):
                # Merge a clitic into the previous word
                data[speaker_name]["words"][-1].end = end
                data[speaker_name]["words"][-1].label += w
            else:
                data[speaker_name]["words"].append(CtmInterval(beg, end, w))

    def process_utterance_data():
        for beg, end, u, speaker_name in utterance_data:
            data[speaker_name]["utterances"].append(CtmInterval(beg, end, u))

    # NOTE(review): the three result streams are consumed in step, one file per pass;
    # this assumes every file with phone intervals also has word intervals (and
    # utterances, when requested) -- confirm against the callers.
    while True:
        data = {}
        for pi_begin, pi_end, phone, pi_speaker_name, pi_file_id in phone_intervals:
            if pi_current_file_id is None:
                pi_current_file_id = pi_file_id
            if pi_file_id != pi_current_file_id:
                # First row of the next file: flush the finished file and keep
                # this row as the start of the next buffer
                process_phone_data()
                phone_data = [(pi_begin, pi_end, phone, pi_speaker_name)]
                current_file_id = pi_current_file_id
                pi_current_file_id = pi_file_id
                break
            phone_data.append((pi_begin, pi_end, phone, pi_speaker_name))
        else:
            if phone_data:
                process_phone_data()
                current_file_id = pi_current_file_id
                phone_data = []
            else:
                # Phone stream exhausted and nothing buffered: every file has been exported
                break
        for wi_begin, wi_end, word, wi_speaker_name, wi_file_id in word_intervals:
            if wi_current_file_id is None:
                wi_current_file_id = wi_file_id
            if wi_file_id != wi_current_file_id:
                process_word_data()
                word_data = [(wi_begin, wi_end, word, wi_speaker_name)]
                wi_current_file_id = wi_file_id
                break
            word_data.append((wi_begin, wi_end, word, wi_speaker_name))
        else:
            if word_data:
                process_word_data()
                # Clear the buffer so an exhausted word stream cannot replay the
                # last file's words into a later file
                word_data = []
        if include_original_text:
            for u_begin, u_end, text, u_speaker_name, u_file_id in utterances:
                if u_current_file_id is None:
                    u_current_file_id = u_file_id
                if u_file_id != u_current_file_id:
                    process_utterance_data()
                    utterance_data = [(u_begin, u_end, text, u_speaker_name)]
                    u_current_file_id = u_file_id
                    break
                utterance_data.append((u_begin, u_end, text, u_speaker_name))
            else:
                if utterance_data:
                    process_utterance_data()
                    # Same reasoning as for the word buffer above
                    utterance_data = []
        file_name, relative_path, file_duration, text_file_path = file_batch[current_file_id]
        output_path = construct_output_path(
            file_name, relative_path, output_directory, text_file_path, output_format
        )
        export_textgrid(data, output_path, file_duration, frame_shift, output_format)
        yield output_path
def construct_output_path(
    name: str,
    relative_path: Path,
    output_directory: Path,
    input_path: Path = None,
    output_format: str = TextgridFormats.SHORT_TEXTGRID,
) -> Path:
    """
    Construct an output path

    The directory that will contain the output file is created if needed.

    Parameters
    ----------
    name: str
        File name without extension
    relative_path: :class:`~pathlib.Path`
        Path relative to the output directory; if empty, the output directory itself is used
    output_directory: :class:`~pathlib.Path` or str
        Root directory for output files
    input_path: :class:`~pathlib.Path` or str, optional
        Path of the input file; if the output path would be the same path,
        ``_aligned`` is appended to the file name
    output_format: str
        "lab", "json" and "csv" (case-insensitive) select the matching extension;
        anything else produces a ``.TextGrid`` file

    Returns
    -------
    Path
        Output path
    """
    if isinstance(output_directory, str):
        output_directory = Path(output_directory)
    if isinstance(input_path, str):
        # A str never compares equal to a Path, which would hide a collision
        input_path = Path(input_path)
    extensions = {"LAB": ".lab", "JSON": ".json", "CSV": ".csv"}
    extension = extensions.get(output_format.upper(), ".TextGrid")
    if relative_path:
        relative = output_directory.joinpath(relative_path)
    else:
        relative = output_directory
    output_path = relative.joinpath(name + extension)
    if output_path == input_path:
        output_path = relative.joinpath(name + "_aligned" + extension)
    relative.mkdir(parents=True, exist_ok=True)
    return output_path
def export_textgrid(
    speaker_data: Dict[str, Dict[str, List[CtmInterval]]],
    output_path: Path,
    duration: float,
    frame_shift: float,
    output_format: str = TextFileType.TEXTGRID.value,
) -> None:
    """
    Export aligned file to TextGrid, JSON or CSV

    Nothing is written when none of the tiers contains an interval. Interval
    boundaries may be adjusted in place on the passed intervals.

    Parameters
    ----------
    speaker_data: dict[str, dict[str, list[CtmInterval]]]
        Per speaker, per annotation type (e.g. words/phones) intervals
    output_path: :class:`~pathlib.Path`
        Output path of the file
    duration: float
        Duration of the file, rounded to 6 decimal places before use
    frame_shift: float
        Frame shift of features, in seconds; interval ends within two frame
        shifts of the file duration are snapped to the duration
    output_format: str, optional
        "csv" and "json" select those writers; any other value is passed on as
        the ``format`` of :meth:`Textgrid.save`. Defaults to ``TextFileType.TEXTGRID.value``
    """
    duration = round(duration, 6)
    wrote_anything = False

    def tier_name_for(speaker_name, kind):
        # Tiers are only prefixed with the speaker when there are several speakers
        if len(speaker_data) > 1:
            return f"{speaker_name} - {kind}"
        return kind

    if output_format == "csv":
        rows = []
        for speaker, tiers in speaker_data.items():
            for annotation_type, intervals in tiers.items():
                if intervals:
                    wrote_anything = True
                for interval in intervals:
                    if duration - interval.end < (frame_shift * 2):  # Fix rounding issues
                        interval.end = duration
                    rows.append(
                        {
                            "Begin": interval.begin,
                            "End": interval.end,
                            "Label": interval.label,
                            "Type": annotation_type,
                            "Speaker": speaker,
                        }
                    )
        if wrote_anything:
            with mfa_open(output_path, "w") as f:
                writer = csv.DictWriter(f, fieldnames=["Begin", "End", "Label", "Type", "Speaker"])
                writer.writeheader()
                writer.writerows(rows)
    elif output_format == "json":
        json_tiers = {}
        json_data = {"start": 0, "end": duration, "tiers": json_tiers}
        for speaker, tiers in speaker_data.items():
            for annotation_type, intervals in tiers.items():
                tier_entry = json_tiers.setdefault(
                    tier_name_for(speaker, annotation_type), {"type": "interval", "entries": []}
                )
                if intervals:
                    wrote_anything = True
                for interval in intervals:
                    if duration - interval.end < (frame_shift * 2):  # Fix rounding issues
                        interval.end = duration
                    tier_entry["entries"].append([interval.begin, interval.end, interval.label])
        if wrote_anything:
            with mfa_open(output_path, "w") as f:
                json.dump(json_data, f, indent=4, ensure_ascii=False)
    else:
        # Create initial textgrid
        tg = Textgrid()
        tg.minTimestamp = 0
        tg.maxTimestamp = duration
        for speaker, tiers in speaker_data.items():
            for annotation_type, intervals in tiers.items():
                if intervals:
                    wrote_anything = True
                tier_name = tier_name_for(speaker, annotation_type)
                if tier_name not in tg.tierNames:
                    tg.addTier(tgio.IntervalTier(tier_name, [], minT=0, maxT=duration))
                entries = tg.getTier(tier_name)._entries
                final_index = len(intervals) - 1
                for index, interval in enumerate(sorted(intervals, key=lambda x: x.begin)):
                    if index == final_index and duration - interval.end < (frame_shift * 2):
                        # Fix rounding issues
                        interval.end = duration
                    tg_interval = to_tg_interval(interval, duration)
                    if index > 0 and entries[-1].end > tg_interval.start:
                        # Start at the previous interval's end so entries never overlap
                        interval.begin = entries[-1].end
                        tg_interval = to_tg_interval(interval, duration)
                    entries.append(tg_interval)
        if wrote_anything:
            for tier in tg.tiers:
                if len(tier._entries) > 0 and tier._entries[-1][1] > tg.maxTimestamp:
                    # Clip a final entry that runs past the end of the textgrid
                    final_entry = tier._entries[-1]
                    tier.insertEntry(
                        Interval(final_entry.start, tg.maxTimestamp, final_entry.label),
                        collisionMode="replace",
                    )
            tg.save(
                str(output_path),
                includeBlankSpaces=True,
                format=output_format,
                minimumIntervalLength=None,
                reportingMode="silence",
            )