"""
Textgrid utilities
==================
"""
from __future__ import annotations

import csv
import itertools
import json
import os
import re
import typing
from pathlib import Path
from typing import Dict, List

import dataclassy
import sqlalchemy
from kalpy.gmm.data import to_tg_interval
from praatio import textgrid as tgio
from praatio.data_classes.interval_tier import Interval
from praatio.utilities import utils as tgio_utils
from sqlalchemy.orm import Session

from montreal_forced_aligner.data import PhoneType, TextFileType, TextgridFormats, WordType
from montreal_forced_aligner.db import (
    CorpusWorkflow,
    Phone,
    PhoneInterval,
    Speaker,
    Utterance,
    Word,
    WordInterval,
)
from montreal_forced_aligner.exceptions import AlignmentExportError, CtmError, TextGridParseError
from montreal_forced_aligner.helper import mfa_open
# Names exported by ``from <this module> import *``
__all__ = [
    "process_ctm_line",
    "export_textgrid",
    "construct_textgrid_output",
    "construct_output_path",
    "output_textgrid_writing_errors",
]
# noinspection PyUnresolvedReferences
@dataclassy.dataclass(slots=True)
class CtmInterval:
    """
    Data class for intervals derived from CTM files

    Parameters
    ----------
    begin: float
        Start time of interval
    end: float
        End time of interval
    label: int or str
        Text of interval (an integer ID is also accepted)
    confidence: float, optional
        Confidence score of the interval

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.CtmError`
        On construction, if ``end`` is below -1 or ``begin`` equals 1000000
    """

    begin: float
    end: float
    label: typing.Union[int, str]
    confidence: typing.Optional[float] = None

    def __lt__(self, other: CtmInterval) -> bool:
        """Sorting function for CtmIntervals: order by begin time"""
        return self.begin < other.begin

    def __add__(self, other: typing.Union[str, float]):
        """
        Concatenate onto the label, or shift the interval in time

        Parameters
        ----------
        other: str or float
            A string to append to the label, or a time offset

        Returns
        -------
        str or :class:`CtmInterval`
            For a ``str`` operand, ``label + other`` (the interval itself is unchanged).
            Otherwise this same interval, after shifting ``begin`` and ``end`` by ``other``
            in place.
        """
        if isinstance(other, str):
            return self.label + other
        self.begin += other
        self.end += other
        # The shift is in place (existing callers rely on that); returning self also makes
        # ``shifted = interval + offset`` work instead of yielding None
        return self

    def __post_init__(self) -> None:
        """
        Check on data validity

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.CtmError`
            If begin or end are not valid
        """
        if self.end < -1 or self.begin == 1000000:
            raise CtmError(self)

    def to_tg_interval(self, file_duration: typing.Optional[float] = None) -> Interval:
        """
        Converts the CTMInterval to
        `PraatIO's Interval class <http://timmahrt.github.io/praatIO/praatio/utilities/constants.html#Interval>`_

        Parameters
        ----------
        file_duration: float, optional
            If given, the end time is clamped to this duration

        Returns
        -------
        :class:`praatio.utilities.constants.Interval`
            Derived PraatIO Interval, with times rounded to 6 decimals

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.CtmError`
            If ``end`` is below -1, ``begin`` equals 1000000, or the rounded (and clamped)
            interval has ``begin >= end``
        """
        if self.end < -1 or self.begin == 1000000:
            raise CtmError(self)
        end = round(self.end, 6)
        begin = round(self.begin, 6)
        if file_duration is not None and end > file_duration:
            end = round(file_duration, 6)
        if begin >= end:
            raise CtmError(self)
        return Interval(begin, end, self.label)
class Textgrid(tgio.Textgrid):
    """
    PraatIO :class:`~praatio.textgrid.Textgrid` that writes Praat TextGrid files directly

    Long and short TextGrid formats are written line by line by :meth:`save`; the JSON
    variants are delegated to PraatIO's own ``save``.
    """

    @staticmethod
    def _fill_blank_intervals(entries, min_timestamp, max_timestamp) -> list:
        """
        Return a copy of ``entries`` with unlabeled intervals filling gaps longer than 1 ms

        Gaps are filled before the first entry (from ``min_timestamp``), between
        consecutive entries, and after the last entry (up to ``max_timestamp``). An empty
        ``entries`` stays empty, and the input list is not modified.
        """
        filled = []
        for entry in entries:
            previous_end = filled[-1][1] if filled else float(min_timestamp)
            if entry[0] - previous_end > 0.001:
                filled.append(Interval(previous_end, entry[0], ""))
            filled.append(entry)
        if filled and max_timestamp - filled[-1][1] > 0.001:
            filled.append(Interval(filled[-1][1], max_timestamp, ""))
        return filled

    def save(
        self,
        fn: str,
        format: typing.Literal["short_textgrid", "long_textgrid", "json", "textgrid_json"],
        includeBlankSpaces: bool,
        minTimestamp: typing.Optional[float] = None,
        maxTimestamp: typing.Optional[float] = None,
        minimumIntervalLength: typing.Optional[float] = None,
        reportingMode: typing.Literal["silence", "warning", "error"] = "warning",
    ) -> None:
        """Save the current textgrid to a file

        Args:
            fn: the fullpath filename of the output
            format: one of ['short_textgrid', 'long_textgrid', 'json', 'textgrid_json'].
                'short_textgrid' and 'long_textgrid' are written directly by this method;
                'json' and 'textgrid_json' are written by PraatIO's own ``save``.
                Any other value truncates ``fn`` and writes nothing.
            includeBlankSpaces: if True, gaps longer than 1 ms in interval tiers
                (including before the first and after the last entry) are filled
                with an empty interval (label ""). Praat needs blanks to render
                textgrids properly. The tiers themselves are not modified.
            minTimestamp: xmin written for the textgrid and every tier;
                if None, use ``self.minTimestamp``.
            maxTimestamp: xmax written for the textgrid and every tier;
                if None, use ``self.maxTimestamp``.
            minimumIntervalLength: only used for the JSON formats (passed to PraatIO).
            reportingMode: only used for the JSON formats (passed to PraatIO).

        Returns:
            None
        """
        if format in ("json", "textgrid_json"):
            # These formats previously produced an empty file; PraatIO knows how to write them
            super().save(
                fn,
                format=format,
                includeBlankSpaces=includeBlankSpaces,
                minTimestamp=minTimestamp,
                maxTimestamp=maxTimestamp,
                minimumIntervalLength=minimumIntervalLength,
                reportingMode=reportingMode,
            )
            return
        if format == TextgridFormats.LONG_TEXTGRID:
            write_long = True
        elif format == TextgridFormats.SHORT_TEXTGRID:
            write_long = False
        else:
            # NOTE(review): unsupported format; as before, the file is truncated and left
            # empty. Consider raising instead.
            with mfa_open(fn, mode="w"):
                pass
            return

        min_ts = self.minTimestamp if minTimestamp is None else minTimestamp
        max_ts = self.maxTimestamp if maxTimestamp is None else maxTimestamp
        tab = " " * 4
        with mfa_open(fn, mode="w") as fd:
            # Header
            fd.write('File type = "ooTextFile"\n')
            fd.write('Object class = "TextGrid"\n\n')
            if write_long:
                fd.write(f"xmin = {min_ts} \n")
                fd.write(f"xmax = {max_ts} \n")
                fd.write("tiers? <exists> \n")
                fd.write(f"size = {len(self._tierDict)} \n")
                fd.write("item []: \n")
            else:
                fd.write(f"{min_ts}\n{max_ts}\n")
                fd.write(f"<exists>\n{len(self._tierDict)}\n")

            for tier_num, (name, tier) in enumerate(self._tierDict.items()):
                entries = tier._entries
                if includeBlankSpaces:
                    entries = self._fill_blank_intervals(entries, min_ts, max_ts)
                tier_name = tgio_utils.escapeQuotes(name)

                # Tier header
                if write_long:
                    fd.write(tab + f"item [{tier_num + 1}]:\n")
                    fd.write(tab * 2 + f'class = "{tier.tierType}" \n')
                    fd.write(tab * 2 + f'name = "{tier_name}" \n')
                    fd.write(tab * 2 + f"xmin = {min_ts} \n")
                    fd.write(tab * 2 + f"xmax = {max_ts} \n")
                    fd.write(tab * 2 + f"intervals: size = {len(entries)} \n")
                else:
                    fd.write(f'"{tier.tierType}"\n')
                    fd.write(f'"{tier_name}"\n')
                    fd.write(f"{min_ts}\n{max_ts}\n{len(entries)}\n")

                # Tier entries
                for i, (start, end, label) in enumerate(entries):
                    label = tgio_utils.escapeQuotes(label)
                    if write_long:
                        fd.write(
                            f"{tab * 2}intervals [{i + 1}]:\n"
                            f"{tab * 3}xmin = {start} \n"
                            f"{tab * 3}xmax = {end} \n"
                            f'{tab * 3}text = "{label}" \n'
                        )
                    else:
                        fd.write(f'{start}\n{end}\n"{label}"\n')
def process_ctm_line(
    line: str, reversed_phone_mapping: Dict[int, str], raw_id=False
) -> typing.Tuple[typing.Union[int, str], CtmInterval]:
    """
    Helper function for parsing a line of CTM file to construct a CTMInterval

    CTM format is:

    utt_id channel_num start_time phone_dur phone_id [confidence]

    Parameters
    ----------
    line: str
        Input string
    reversed_phone_mapping: dict[int, str]
        Mapping from integer phone IDs to phone labels
    raw_id: bool, optional
        If True, return ``utt_id`` unchanged; otherwise return the integer after its
        last "-"

    Returns
    -------
    int or str
        Utterance identifier
    :class:`CtmInterval`
        Interval with begin/end rounded to 4 decimals, the mapped phone label, and the
        confidence (rounded to 4 decimals) if the line has one
    """
    fields = line.split()
    utt = fields[0] if raw_id else int(fields[0].split("-")[-1])
    begin = round(float(fields[2]), 4)
    end = round(begin + float(fields[3]), 4)
    conf = round(float(fields[5]), 4) if len(fields) > 5 else None
    label = reversed_phone_mapping[int(fields[4])]
    return utt, CtmInterval(begin, end, label, confidence=conf)
def output_textgrid_writing_errors(
    output_directory: str, export_errors: Dict[str, AlignmentExportError]
) -> None:
    """
    Output any errors that were encountered in writing TextGrids

    Any existing ``output_errors.txt`` in ``output_directory`` is removed; a new one is
    written only when ``export_errors`` is non-empty.

    Parameters
    ----------
    output_directory: str
        Directory to save TextGrids files
    export_errors: dict[str, :class:`~montreal_forced_aligner.exceptions.AlignmentExportError`]
        Dictionary of errors encountered
    """
    error_log = os.path.join(output_directory, "output_errors.txt")
    if os.path.exists(error_log):
        os.remove(error_log)
    if not export_errors:
        return
    # Open the log once rather than re-checking and re-opening it for every error
    with mfa_open(error_log, "w") as f:
        f.write(
            "The following exceptions were encountered during the output of the alignments to TextGrids:\n\n"
        )
        for result in export_errors.values():
            f.write(f"{str(result)}\n\n")
def parse_aligned_textgrid(
    path: Path, root_speaker: typing.Optional[str] = None
) -> Dict[str, List[CtmInterval]]:
    """
    Load a TextGrid as a dictionary of speaker's phone tiers

    Only interval tiers whose name contains "phones" are read. A tier named like
    "<speaker> - phones" is assigned to <speaker>; any other phone tier goes to
    ``root_speaker``, or to "" when that is not given. Blank intervals and intervals
    shorter than 0.01 s (after rounding times to 4 decimals) are skipped.

    Parameters
    ----------
    path: :class:`~pathlib.Path`
        TextGrid file to parse
    root_speaker: str, optional
        Optional speaker if the TextGrid has no speaker information

    Returns
    -------
    dict[str, list[:class:`CtmInterval`]]
        Parsed phone tier

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.TextGridParseError`
        If the TextGrid contains no tiers
    """
    textgrid = tgio.openTextgrid(path, includeEmptyIntervals=False, reportingMode="silence")
    if not textgrid.tiers:
        raise TextGridParseError(path, "Number of tiers parsed was zero")
    speaker_pattern = re.compile(r"(.*) ?- ?phones")
    parsed: Dict[str, List[CtmInterval]] = {}
    for name in textgrid.tierNames:
        tier = textgrid._tierDict[name]
        if "phones" not in name or not isinstance(tier, tgio.IntervalTier):
            continue
        match = speaker_pattern.match(name)
        speaker = match.group(1).strip() if match else (root_speaker or "")
        intervals = parsed.setdefault(speaker, [])
        for begin, end, text in tier.entries:
            label = text.strip()
            if not label:
                continue
            begin, end = round(begin, 4), round(end, 4)
            if end - begin < 0.01:
                continue
            intervals.append(CtmInterval(begin, end, label))
    return parsed
def _rows_by_file(rows) -> typing.Iterator[typing.Tuple[int, list]]:
    """Group ``(begin, end, label, speaker, file_id)`` rows ordered by file id into per-file lists"""
    records = (
        (file_id, (begin, end, label, speaker_name))
        for begin, end, label, speaker_name, file_id in rows
    )
    for file_id, group in itertools.groupby(records, key=lambda record: record[0]):
        yield file_id, [values for _, values in group]


def _file_row_lookup(rows) -> typing.Callable[[int], list]:
    """
    Wrap a row stream ordered by file id as a function returning one file's rows

    The returned function must be called with non-decreasing file ids (integer ids, as
    keyed in ``file_batch``). Rows of file ids that are never requested are skipped, and a
    file without rows gets an empty list, so the stream cannot drift out of step with the
    file ids driving the export.
    """
    groups = _rows_by_file(rows)
    pending = next(groups, None)

    def rows_for_file(file_id: int) -> list:
        nonlocal pending
        found = []
        while pending is not None and pending[0] <= file_id:
            if pending[0] == file_id:
                found = pending[1]
            pending = next(groups, None)
        return found

    return rows_for_file


def _speaker_tiers(data: dict, speaker_name: str, include_original_text: bool) -> dict:
    """Return ``data[speaker_name]``, creating its tier lists (words, phones[, utterances]) if needed"""
    if speaker_name not in data:
        data[speaker_name] = {"words": [], "phones": []}
        if include_original_text:
            data[speaker_name]["utterances"] = []
    return data[speaker_name]


def construct_textgrid_output(
    session: Session,
    file_batch: typing.Dict[int, typing.Tuple],
    workflow: CorpusWorkflow,
    cleanup_textgrids: bool,
    clitic_marker: str,
    output_directory: Path,
    frame_shift: float,
    output_format: str = TextgridFormats.SHORT_TEXTGRID,
    include_original_text: bool = False,
):
    """
    Export the aligned phones and words of a batch of files, one output file per file

    Parameters
    ----------
    session: :class:`~sqlalchemy.orm.Session`
        Database session
    file_batch: dict[int, tuple]
        Mapping of file id to ``(file_name, relative_path, file_duration, text_file_path)``
    workflow: :class:`~montreal_forced_aligner.db.CorpusWorkflow`
        Currently unused by this function
    cleanup_textgrids: bool
        If True, drop silence phones/words and merge clitics into adjacent words
    clitic_marker: str
        Marker identifying clitics for merging
    output_directory: :class:`~pathlib.Path`
        Root directory for exported files
    frame_shift: float
        Frame shift of features, in seconds
    output_format: str
        Output format passed to :func:`export_textgrid`
    include_original_text: bool
        If True, also export an "utterances" tier with each utterance's text

    Yields
    ------
    :class:`~pathlib.Path`
        Path of each exported file; only files with at least one phone interval are exported
    """
    file_ids = list(file_batch.keys())
    phone_interval_query = (
        sqlalchemy.select(
            PhoneInterval.begin, PhoneInterval.end, Phone.phone, Speaker.name, Utterance.file_id
        )
        .execution_options(yield_per=1000)
        .join(PhoneInterval.phone)
        .join(PhoneInterval.utterance)
        .join(Utterance.speaker)
        .filter(PhoneInterval.duration > 0)
        .filter(Utterance.file_id.in_(file_ids))
    )
    word_interval_query = (
        sqlalchemy.select(
            WordInterval.begin, WordInterval.end, Word.word, Speaker.name, Utterance.file_id
        )
        .execution_options(yield_per=1000)
        .join(WordInterval.word)
        .join(WordInterval.utterance)
        .join(Utterance.speaker)
        .filter(WordInterval.duration > 0)
        .filter(Utterance.file_id.in_(file_ids))
    )
    if cleanup_textgrids:
        phone_interval_query = phone_interval_query.filter(Phone.phone_type != PhoneType.silence)
        word_interval_query = word_interval_query.filter(Word.word_type != WordType.silence)
    phone_intervals = session.execute(
        phone_interval_query.order_by(Utterance.file_id, PhoneInterval.begin)
    )
    word_intervals = session.execute(
        word_interval_query.order_by(Utterance.file_id, WordInterval.begin)
    )
    words_for_file = _file_row_lookup(word_intervals)
    utterances_for_file = None
    if include_original_text:
        utterances = session.execute(
            sqlalchemy.select(
                Utterance.begin, Utterance.end, Utterance.text, Speaker.name, Utterance.file_id
            )
            .execution_options(yield_per=1000)
            .join(Utterance.speaker)
            .filter(Utterance.file_id.in_(file_ids))
            .order_by(Utterance.file_id)
        )
        utterances_for_file = _file_row_lookup(utterances)

    # Phones drive the export; words and utterances are merge-joined on file id so a file
    # missing from one stream cannot shift the other streams onto the wrong file
    for file_id, phone_rows in _rows_by_file(phone_intervals):
        data = {}
        for begin, end, phone, speaker_name in phone_rows:
            tiers = _speaker_tiers(data, speaker_name, include_original_text)
            tiers["phones"].append(CtmInterval(begin, end, phone))
        for begin, end, word, speaker_name in words_for_file(file_id):
            words = _speaker_tiers(data, speaker_name, include_original_text)["words"]
            if (
                cleanup_textgrids
                and words
                and begin - words[-1].end < 0.02
                and clitic_marker
                and (words[-1].label.endswith(clitic_marker) or word.startswith(clitic_marker))
            ):
                # Merge the clitic into the adjacent word
                words[-1].end = end
                words[-1].label += word
            else:
                words.append(CtmInterval(begin, end, word))
        if utterances_for_file is not None:
            for begin, end, text, speaker_name in utterances_for_file(file_id):
                tiers = _speaker_tiers(data, speaker_name, include_original_text)
                tiers["utterances"].append(CtmInterval(begin, end, text))
        file_name, relative_path, file_duration, text_file_path = file_batch[file_id]
        output_path = construct_output_path(
            file_name, relative_path, output_directory, text_file_path, output_format
        )
        export_textgrid(data, output_path, file_duration, frame_shift, output_format)
        yield output_path
def construct_output_path(
    name: str,
    relative_path: Path,
    output_directory: Path,
    input_path: typing.Optional[Path] = None,
    output_format: str = TextgridFormats.SHORT_TEXTGRID,
) -> Path:
    """
    Construct an output path and create the directory that will contain it

    Parameters
    ----------
    name: str
        File name without extension
    relative_path: :class:`~pathlib.Path`
        Subdirectory of ``output_directory`` for the file; if falsy, the file goes
        directly into ``output_directory``
    output_directory: :class:`~pathlib.Path`
        Root output directory (a ``str`` is also accepted)
    input_path: :class:`~pathlib.Path`, optional
        Path of the original input file (a ``str`` is also accepted); if the
        constructed path equals it, "_aligned" is appended to the name so the input
        is not overwritten
    output_format: str
        "lab", "json" and "csv" (case-insensitive) get matching extensions; any other
        format gets ".TextGrid"

    Returns
    -------
    :class:`~pathlib.Path`
        Output path
    """
    output_directory = Path(output_directory)
    extension = {"LAB": ".lab", "JSON": ".json", "CSV": ".csv"}.get(
        output_format.upper(), ".TextGrid"
    )
    directory = output_directory.joinpath(relative_path) if relative_path else output_directory
    output_path = directory.joinpath(name + extension)
    # Compare as Paths: a str input_path never equals a Path and would let the export
    # overwrite the original file
    if input_path is not None and output_path == Path(input_path):
        output_path = directory.joinpath(name + "_aligned" + extension)
    directory.mkdir(parents=True, exist_ok=True)
    return output_path
def _snap_end_to_duration(interval: CtmInterval, duration: float, frame_shift: float) -> None:
    """Set ``interval.end`` to ``duration`` if it ends within two frames of it (or past it)"""
    if duration - interval.end < (frame_shift * 2):  # Fix rounding issues
        interval.end = duration


def _final_interval(intervals: List[CtmInterval]) -> CtmInterval:
    """Return the interval that comes last when ``intervals`` are sorted by begin time"""
    return sorted(intervals, key=lambda x: x.begin)[-1]


def _export_tier_name(speaker: str, annotation_type: str, num_speakers: int) -> str:
    """Tier name, prefixed with the speaker only when the file has several speakers"""
    if num_speakers > 1:
        return f"{speaker} - {annotation_type}"
    return annotation_type


def _export_csv(
    speaker_data: Dict[str, Dict[str, List[CtmInterval]]],
    output_path: Path,
    duration: float,
    frame_shift: float,
) -> None:
    """Write all intervals as CSV rows; nothing is written when there are no intervals"""
    rows = []
    for speaker, data in speaker_data.items():
        for annotation_type, intervals in data.items():
            if not intervals:
                continue
            # Only the final interval is snapped, as in the TextGrid export; snapping every
            # interval near the end could make several end at ``duration`` and overlap
            _snap_end_to_duration(_final_interval(intervals), duration, frame_shift)
            rows.extend(
                {
                    "Begin": a.begin,
                    "End": a.end,
                    "Label": a.label,
                    "Type": annotation_type,
                    "Speaker": speaker,
                }
                for a in intervals
            )
    if not rows:
        return
    with mfa_open(output_path, "w") as f:
        writer = csv.DictWriter(f, fieldnames=["Begin", "End", "Label", "Type", "Speaker"])
        writer.writeheader()
        writer.writerows(rows)


def _export_json(
    speaker_data: Dict[str, Dict[str, List[CtmInterval]]],
    output_path: Path,
    duration: float,
    frame_shift: float,
) -> None:
    """Write intervals as JSON tiers; nothing is written when there are no intervals"""
    json_data = {"start": 0, "end": duration, "tiers": {}}
    has_data = False
    for speaker, data in speaker_data.items():
        for annotation_type, intervals in data.items():
            tier_name = _export_tier_name(speaker, annotation_type, len(speaker_data))
            tier = json_data["tiers"].setdefault(tier_name, {"type": "interval", "entries": []})
            if not intervals:
                continue
            has_data = True
            # Only the final interval is snapped, as in the TextGrid export
            _snap_end_to_duration(_final_interval(intervals), duration, frame_shift)
            tier["entries"].extend([a.begin, a.end, a.label] for a in intervals)
    if has_data:
        with mfa_open(output_path, "w") as f:
            json.dump(json_data, f, indent=4, ensure_ascii=False)


def _export_praat_textgrid(
    speaker_data: Dict[str, Dict[str, List[CtmInterval]]],
    output_path: Path,
    duration: float,
    frame_shift: float,
    output_format: str,
) -> None:
    """Write intervals via :meth:`Textgrid.save`; nothing is written when there are no intervals"""
    tg = Textgrid()
    tg.minTimestamp = 0
    tg.maxTimestamp = duration
    has_data = False
    for speaker, data in speaker_data.items():
        for annotation_type, intervals in data.items():
            if len(intervals):
                has_data = True
            tier_name = _export_tier_name(speaker, annotation_type, len(speaker_data))
            if tier_name not in tg.tierNames:
                tg.addTier(tgio.IntervalTier(tier_name, [], minT=0, maxT=duration))
            tier = tg.getTier(tier_name)
            for i, a in enumerate(sorted(intervals, key=lambda x: x.begin)):
                if i == len(intervals) - 1:
                    _snap_end_to_duration(a, duration, frame_shift)
                tg_interval = to_tg_interval(a, duration)
                if i > 0 and tier._entries[-1].end > tg_interval.start:
                    # Overlaps the previous interval: start it where the previous one ended
                    a.begin = tier._entries[-1].end
                    tg_interval = to_tg_interval(a, duration)
                tier._entries.append(tg_interval)
    if not has_data:
        return
    for tier in tg.tiers:
        if len(tier._entries) > 0 and tier._entries[-1][1] > tg.maxTimestamp:
            tier.insertEntry(
                Interval(tier._entries[-1].start, tg.maxTimestamp, tier._entries[-1].label),
                collisionMode="replace",
            )
    tg.save(
        str(output_path),
        includeBlankSpaces=True,
        format=output_format,
        minimumIntervalLength=None,
        reportingMode="silence",
    )


def export_textgrid(
    speaker_data: Dict[str, Dict[str, List[CtmInterval]]],
    output_path: Path,
    duration: float,
    frame_shift: float,
    output_format: str = TextFileType.TEXTGRID.value,
) -> None:
    """
    Export aligned file to TextGrid, JSON or CSV

    Tiers are named after the annotation type, prefixed with "<speaker> - " when there is
    more than one speaker. The final interval of each tier (by begin time) is extended to
    ``duration`` when it ends within two frames of it; this mutates the given intervals.
    No file is written when there are no intervals at all.

    Parameters
    ----------
    speaker_data: dict[str, dict[str, list[:class:`CtmInterval`]]]
        Per speaker, per annotation type (e.g. words/phones) intervals
    output_path: :class:`~pathlib.Path`
        Output path of the file
    duration: float
        Duration of the file (rounded to 6 decimals before use)
    frame_shift: float
        Frame shift of features, in seconds
    output_format: str, optional
        "csv" or "json" are written here; any other value (default
        ``TextFileType.TEXTGRID.value``) is passed as the format to :meth:`Textgrid.save`,
        e.g. "long_textgrid" or "short_textgrid"
    """
    duration = round(duration, 6)
    if output_format == "csv":
        _export_csv(speaker_data, output_path, duration, frame_shift)
    elif output_format == "json":
        _export_json(speaker_data, output_path, duration, frame_shift)
    else:
        _export_praat_textgrid(speaker_data, output_path, duration, frame_shift, output_format)