"""
Textgrid utilities
==================
"""
from __future__ import annotations
import csv
import json
import os
import re
import typing
from pathlib import Path
from typing import Dict, List
import sqlalchemy
from praatio import textgrid as tgio
from praatio.data_classes.interval_tier import Interval
from praatio.utilities import utils as tgio_utils
from sqlalchemy.orm import Session
from montreal_forced_aligner.data import (
CtmInterval,
PhoneType,
TextFileType,
TextgridFormats,
WordType,
)
from montreal_forced_aligner.db import (
CorpusWorkflow,
Phone,
PhoneInterval,
Speaker,
Utterance,
Word,
WordInterval,
)
from montreal_forced_aligner.exceptions import AlignmentExportError, TextGridParseError
from montreal_forced_aligner.helper import mfa_open
__all__ = [
"process_ctm_line",
"export_textgrid",
"construct_textgrid_output",
"construct_output_path",
"output_textgrid_writing_errors",
]
class Textgrid(tgio.Textgrid):
    """
    praatio Textgrid with a custom writer for Praat's long and short text formats
    """

    def save(
        self,
        fn: str,
        format: typing.Literal["short_textgrid", "long_textgrid", "json", "textgrid_json"],
        includeBlankSpaces: bool,
        minTimestamp: typing.Optional[float] = None,
        maxTimestamp: typing.Optional[float] = None,
        minimumIntervalLength: typing.Optional[float] = None,
        reportingMode: typing.Literal["silence", "warning", "error"] = "warning",
    ) -> None:
        """Save the current textgrid to a file

        Only the two Praat text formats are written here. For any other
        ``format`` value the file is still opened for writing (and therefore
        created or truncated) but nothing is written to it.

        When ``includeBlankSpaces`` is set, the empty filler intervals are
        inserted into the tiers themselves, so the textgrid object is modified
        by saving it.

        Args:
            fn: the fullpath filename of the output
            format: 'short_textgrid' or 'long_textgrid' (both used by praat);
                other values produce an empty file
            includeBlankSpaces: if True, blank sections in non-empty interval
                tiers will be filled in with an empty interval
                (with a label of ""). If you are unsure, True is recommended
                as Praat needs blanks to render textgrids properly.
            minTimestamp: not read by this implementation; the textgrid's own
                ``minTimestamp`` is always written (presumably kept so the
                signature matches the parent class -- confirm)
            maxTimestamp: not read by this implementation; the textgrid's own
                ``maxTimestamp`` is always written
            minimumIntervalLength: not read by this implementation; no
                intervals are removed
            reportingMode: not read by this implementation

        Returns:
            None; the textgrid is written to ``fn``
        """
        tab = " " * 4
        is_long = format == TextgridFormats.LONG_TEXTGRID
        is_short = format == TextgridFormats.SHORT_TEXTGRID
        with mfa_open(fn, mode="w") as fd:
            if format not in {TextgridFormats.LONG_TEXTGRID, TextgridFormats.SHORT_TEXTGRID}:
                # Unsupported format: leave the freshly opened file empty
                return

            # File header
            fd.write('File type = "ooTextFile"\n')
            fd.write('Object class = "TextGrid"\n\n')
            if is_long:
                fd.write(f"xmin = {self.minTimestamp} \n")
                fd.write(f"xmax = {self.maxTimestamp} \n")
                fd.write("tiers? <exists> \n")
                fd.write(f"size = {len(self._tierDict)} \n")
                fd.write("item []: \n")
            elif is_short:
                fd.write(f"{self.minTimestamp}\n{self.maxTimestamp}\n")
                fd.write(f"<exists>\n{len(self._tierDict)}\n")

            for tier_number, (name, tier) in enumerate(self._tierDict.items(), start=1):
                if includeBlankSpaces:
                    self._fill_blank_intervals(tier)
                tier_name = tgio_utils.escapeQuotes(name)

                # Tier header
                if is_long:
                    fd.write(tab + f"item [{tier_number}]:\n")
                    fd.write(tab * 2 + f'class = "{tier.tierType}" \n')
                    fd.write(tab * 2 + f'name = "{tier_name}" \n')
                    fd.write(tab * 2 + f"xmin = {self.minTimestamp} \n")
                    fd.write(tab * 2 + f"xmax = {self.maxTimestamp} \n")
                    fd.write(tab * 2 + f"intervals: size = {len(tier._entries)} \n")
                elif is_short:
                    fd.write(f'"{tier.tierType}"\n')
                    fd.write(f'"{tier_name}"\n')
                    fd.write(
                        f"{self.minTimestamp}\n{self.maxTimestamp}\n{len(tier._entries)}\n"
                    )

                # Tier entries
                for entry_number, (start, end, label) in enumerate(tier._entries, start=1):
                    label = tgio_utils.escapeQuotes(label)
                    if is_long:
                        fd.write(
                            f"{tab * 2}intervals [{entry_number}]:\n"
                            f"{tab * 3}xmin = {start} \n"
                            f"{tab * 3}xmax = {end} \n"
                            f'{tab * 3}text = "{label}" \n'
                        )
                    elif is_short:
                        fd.write(f'{start}\n{end}\n"{label}"\n')

    def _fill_blank_intervals(self, tier) -> None:
        """Insert empty-label intervals into every gap wider than 1 ms in a non-empty tier

        Gaps are filled before the first entry, between consecutive entries and
        after the last entry. The tier's entry list is modified in place; tiers
        without entries are left untouched.
        """
        entries = tier._entries
        if not entries:
            return
        # Leading gap is measured from the textgrid's own start so the filler
        # never begins before the xmin written in the headers
        tier_start = float(self.minTimestamp) if self.minTimestamp is not None else 0.0
        if entries[0][0] - tier_start > 0.001:
            entries.insert(0, Interval(tier_start, entries[0][0], ""))
        index = 1
        while index < len(entries):
            previous_end = entries[index - 1][1]
            start = entries[index][0]
            if start - previous_end > 0.001:
                entries.insert(index, Interval(previous_end, start, ""))
                # Skip over the filler that was just inserted
                index += 1
            index += 1
        if self.maxTimestamp - entries[-1][1] > 0.001:
            entries.append(Interval(entries[-1][1], self.maxTimestamp, ""))
[docs]
def process_ctm_line(
    line: str, reversed_phone_mapping: Dict[int, str], raw_id: bool = False
) -> typing.Tuple[typing.Union[int, str], CtmInterval]:
    """
    Helper function for parsing a line of CTM file to construct a CTMInterval

    CTM format is:

    utt_id channel_num start_time phone_dur phone_id [confidence]

    Parameters
    ----------
    line: str
        Input string
    reversed_phone_mapping: dict[int, str]
        Mapping from integer IDs to phone labels
    raw_id: bool
        If True, the utterance ID field is returned unchanged as a string;
        otherwise the part after its last "-" is converted to an integer

    Returns
    -------
    int or str
        Utterance identifier
    :class:`~montreal_forced_aligner.data.CtmInterval`
        Extracted data from the line

    Raises
    ------
    IndexError
        If the line has fewer than five whitespace-separated fields
    KeyError
        If the phone ID is not in ``reversed_phone_mapping``
    """
    fields = line.split()
    utt = fields[0]
    if not raw_id:
        utt = int(utt.split("-")[-1])
    begin = round(float(fields[2]), 4)
    # The CTM line stores a duration rather than an end time
    end = round(begin + float(fields[3]), 4)
    phone_id = fields[4]
    # The confidence column is optional
    conf = round(float(fields[5]), 4) if len(fields) > 5 else None
    label = reversed_phone_mapping[int(phone_id)]
    return utt, CtmInterval(begin, end, label, confidence=conf)
[docs]
def output_textgrid_writing_errors(
    output_directory: str, export_errors: Dict[str, AlignmentExportError]
) -> None:
    """
    Output any errors that were encountered in writing TextGrids

    Any ``output_errors.txt`` left in the output directory is removed first, so
    the file only exists afterwards when there are errors to report.

    Parameters
    ----------
    output_directory: str
        Directory to save TextGrids files
    export_errors: dict[str, :class:`~montreal_forced_aligner.exceptions.AlignmentExportError`]
        Dictionary of errors encountered
    """
    error_log = os.path.join(output_directory, "output_errors.txt")
    if os.path.exists(error_log):
        os.remove(error_log)
    if not export_errors:
        return
    # Open the log once rather than reopening it for every error
    with mfa_open(error_log, "w") as f:
        f.write(
            "The following exceptions were encountered during the output of the alignments to TextGrids:\n\n"
        )
        for result in export_errors.values():
            f.write(f"{result!s}\n\n")
def parse_aligned_textgrid(
    path: Path, root_speaker: typing.Optional[str] = None
) -> Dict[str, List[CtmInterval]]:
    """
    Read the phone tiers of a TextGrid, grouped by speaker

    Only interval tiers whose name contains "phones" are used. The speaker is
    taken from a "<speaker> - phones" tier name when it has that form, otherwise
    ``root_speaker`` is used (or an empty string when none is given). Labels are
    lowercased and stripped; intervals with an empty label or lasting less than
    10 ms after rounding to four decimals are dropped.

    Parameters
    ----------
    path: :class:`~pathlib.Path`
        TextGrid file to parse
    root_speaker: str, optional
        Optional speaker if the TextGrid has no speaker information

    Returns
    -------
    dict[str, list[:class:`~montreal_forced_aligner.data.CtmInterval`]]
        Parsed phone tier

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.TextGridParseError`
        If the TextGrid contains no tiers
    """
    textgrid = tgio.openTextgrid(path, includeEmptyIntervals=False, reportingMode="silence")
    if len(textgrid.tiers) == 0:
        raise TextGridParseError(path, "Number of tiers parsed was zero")
    speaker_pattern = re.compile(r"(.*) ?- ?phones")
    phones_by_speaker = {}
    for tier_name in textgrid.tierNames:
        tier = textgrid._tierDict[tier_name]
        if not isinstance(tier, tgio.IntervalTier) or "phones" not in tier_name:
            continue
        matched = speaker_pattern.match(tier_name)
        if matched:
            speaker_name = matched.group(1).strip()
        else:
            speaker_name = root_speaker or ""
        speaker_intervals = phones_by_speaker.setdefault(speaker_name, [])
        for raw_begin, raw_end, raw_label in tier.entries:
            label = raw_label.lower().strip()
            if not label:
                continue
            begin = round(raw_begin, 4)
            end = round(raw_end, 4)
            if end - begin < 0.01:
                continue
            speaker_intervals.append(CtmInterval(begin, end, label))
    return phones_by_speaker
def construct_textgrid_output(
    session: Session,
    file_batch: typing.Dict[int, typing.Tuple],
    workflow: CorpusWorkflow,
    cleanup_textgrids: bool,
    clitic_marker: str,
    output_directory: Path,
    frame_shift: float,
    output_format: str = TextgridFormats.SHORT_TEXTGRID,
    include_original_text: bool = False,
) -> typing.Generator[Path, None, None]:
    """
    Export the alignments of a batch of files, yielding each output path as it is written

    Phone and word intervals (and optionally the utterance texts) are read from
    the database ordered by file ID, grouped per file and per speaker, and
    passed to :func:`export_textgrid`.

    Parameters
    ----------
    session: :class:`~sqlalchemy.orm.Session`
        Database session to run the queries in
    file_batch: dict[int, tuple]
        Mapping of file ID to a (file name, relative path, file duration, text file path) tuple
    workflow: :class:`~montreal_forced_aligner.db.CorpusWorkflow`
        Workflow whose intervals are exported
    cleanup_textgrids: bool
        If True, silence phones and words are left out and words are merged
        across clitic markers
    clitic_marker: str
        Marker used to decide whether neighbouring words are merged
    output_directory: :class:`~pathlib.Path`
        Directory to save output files in
    frame_shift: float
        Frame shift of features, in seconds
    output_format: str, optional
        Output format passed on to :func:`construct_output_path` and :func:`export_textgrid`
    include_original_text: bool
        If True, an "utterances" tier holding the utterance texts is added

    Yields
    ------
    :class:`~pathlib.Path`
        Path of each exported file
    """
    # Rows are (begin, end, phone label, speaker name, file ID), restricted to
    # this workflow, to intervals with a positive duration and to the batch's files
    phone_interval_query = (
        sqlalchemy.select(
            PhoneInterval.begin, PhoneInterval.end, Phone.phone, Speaker.name, Utterance.file_id
        )
        .execution_options(yield_per=1000)
        .join(PhoneInterval.phone)
        .join(PhoneInterval.utterance)
        .join(Utterance.speaker)
        .filter(PhoneInterval.workflow_id == workflow.id)
        .filter(PhoneInterval.duration > 0)
        .filter(Utterance.file_id.in_(list(file_batch.keys())))
    )
    # Same shape and restrictions as the phone query, for word intervals
    word_interval_query = (
        sqlalchemy.select(
            WordInterval.begin, WordInterval.end, Word.word, Speaker.name, Utterance.file_id
        )
        .execution_options(yield_per=1000)
        .join(WordInterval.word)
        .join(WordInterval.utterance)
        .join(Utterance.speaker)
        .filter(WordInterval.workflow_id == workflow.id)
        .filter(WordInterval.duration > 0)
        .filter(Utterance.file_id.in_(list(file_batch.keys())))
    )
    if cleanup_textgrids:
        # Leave silence out of both tiers
        phone_interval_query = phone_interval_query.filter(Phone.phone_type != PhoneType.silence)
        word_interval_query = word_interval_query.filter(Word.word_type != WordType.silence)
    # Ordering by file ID is what lets the loops below detect file boundaries
    phone_intervals = session.execute(
        phone_interval_query.order_by(Utterance.file_id, PhoneInterval.begin)
    )
    word_intervals = session.execute(
        word_interval_query.order_by(Utterance.file_id, WordInterval.begin)
    )
    utterances = None
    if include_original_text:
        # Ordered by file ID only, not by begin time within a file
        utterances = session.execute(
            sqlalchemy.select(
                Utterance.begin, Utterance.end, Utterance.text, Speaker.name, Utterance.file_id
            )
            .execution_options(yield_per=1000)
            .join(Utterance.speaker)
            .filter(Utterance.file_id.in_(list(file_batch.keys())))
            .order_by(Utterance.file_id)
        )
    # File ID each result stream is currently positioned on
    pi_current_file_id = None
    wi_current_file_id = None
    u_current_file_id = None
    # Rows buffered for the file currently being assembled
    word_data = []
    phone_data = []
    utterance_data = []

    def process_phone_data():
        """Move the buffered phone rows into ``data``, creating the per-speaker tiers."""
        for beg, end, p, speaker_name in phone_data:
            if speaker_name not in data:
                data[speaker_name] = {"words": [], "phones": []}
                if include_original_text:
                    data[speaker_name]["utterances"] = []
            data[speaker_name]["phones"].append(CtmInterval(beg, end, p))

    def process_word_data():
        """Move the buffered word rows into ``data``, merging words across clitic markers."""
        for beg, end, w, speaker_name in word_data:
            # With cleanup enabled, a word starting less than 20 ms after the
            # previous one is appended to it when either side carries the clitic marker
            if (
                cleanup_textgrids
                and data[speaker_name]["words"]
                and beg - data[speaker_name]["words"][-1].end < 0.02
                and clitic_marker
                and (
                    data[speaker_name]["words"][-1].label.endswith(clitic_marker)
                    or w.startswith(clitic_marker)
                )
            ):
                data[speaker_name]["words"][-1].end = end
                data[speaker_name]["words"][-1].label += w
            else:
                data[speaker_name]["words"].append(CtmInterval(beg, end, w))

    def process_utterance_data():
        """Move the buffered utterance rows into ``data``."""
        for beg, end, u, speaker_name in utterance_data:
            data[speaker_name]["utterances"].append(CtmInterval(beg, end, u))

    # Each pass of this loop assembles and exports exactly one file.
    # NOTE(review): the word and utterance streams track their own file IDs and
    # the speaker entries in ``data`` are only created from phone rows, so this
    # appears to assume every exported file has rows in every stream -- confirm.
    while True:
        data = {}
        for pi_begin, pi_end, phone, pi_speaker_name, pi_file_id in phone_intervals:
            if pi_current_file_id is None:
                pi_current_file_id = pi_file_id
            if pi_file_id != pi_current_file_id:
                # First row of the next file: flush the finished file and keep
                # this row as the start of the next buffer
                process_phone_data()
                phone_data = [(pi_begin, pi_end, phone, pi_speaker_name)]
                current_file_id = pi_current_file_id
                pi_current_file_id = pi_file_id
                break
            phone_data.append((pi_begin, pi_end, phone, pi_speaker_name))
        else:
            # Phone stream exhausted: flush the last file, or stop when nothing is left
            if phone_data:
                process_phone_data()
                current_file_id = pi_current_file_id
                phone_data = []
            else:
                break
        for wi_begin, wi_end, word, wi_speaker_name, wi_file_id in word_intervals:
            if wi_current_file_id is None:
                wi_current_file_id = wi_file_id
            if wi_file_id != wi_current_file_id:
                process_word_data()
                word_data = [(wi_begin, wi_end, word, wi_speaker_name)]
                wi_current_file_id = wi_file_id
                break
            word_data.append((wi_begin, wi_end, word, wi_speaker_name))
        else:
            if word_data:
                process_word_data()
        if include_original_text:
            for u_begin, u_end, text, u_speaker_name, u_file_id in utterances:
                if u_current_file_id is None:
                    u_current_file_id = u_file_id
                if u_file_id != u_current_file_id:
                    process_utterance_data()
                    utterance_data = [(u_begin, u_end, text, u_speaker_name)]
                    u_current_file_id = u_file_id
                    break
                utterance_data.append((u_begin, u_end, text, u_speaker_name))
            else:
                if utterance_data:
                    process_utterance_data()
        # The file being exported is identified by the phone stream's file ID
        file_name, relative_path, file_duration, text_file_path = file_batch[current_file_id]
        output_path = construct_output_path(
            file_name, relative_path, output_directory, text_file_path, output_format
        )
        export_textgrid(data, output_path, file_duration, frame_shift, output_format)
        yield output_path
[docs]
def construct_output_path(
    name: str,
    relative_path: Path,
    output_directory: Path,
    input_path: typing.Optional[Path] = None,
    output_format: str = TextgridFormats.SHORT_TEXTGRID,
) -> Path:
    """
    Construct an output path

    The directory that will contain the output file is created if needed.

    Parameters
    ----------
    name: str
        Output file name without extension
    relative_path: :class:`~pathlib.Path`
        Path below ``output_directory`` to place the file in; may be empty
    output_directory: :class:`~pathlib.Path`
        Root directory for output files (a string is also accepted)
    input_path: :class:`~pathlib.Path`, optional
        Path of the input file; when the output would land on it, "_aligned"
        is appended to the name instead of overwriting it
    output_format: str, optional
        "lab", "json" and "csv" (case-insensitive) select the matching
        extension; anything else gives ".TextGrid"

    Returns
    -------
    Path
        Output path
    """
    if isinstance(output_directory, str):
        output_directory = Path(output_directory)
    extension = {"LAB": ".lab", "JSON": ".json", "CSV": ".csv"}.get(
        output_format.upper(), ".TextGrid"
    )
    target_directory = (
        output_directory.joinpath(relative_path) if relative_path else output_directory
    )
    output_path = target_directory.joinpath(name + extension)
    # Normalize to Path so a string input path is still recognized as a clash
    if input_path is not None and output_path == Path(input_path):
        output_path = target_directory.joinpath(name + "_aligned" + extension)
    target_directory.mkdir(parents=True, exist_ok=True)
    return output_path
[docs]
def _snap_final_interval(
    intervals: List[CtmInterval], duration: float, frame_shift: float
) -> List[CtmInterval]:
    """Return the intervals sorted by begin, extending the final one to ``duration`` when it ends within two frames of it."""
    ordered = sorted(intervals, key=lambda x: x.begin)
    if ordered and duration - ordered[-1].end < (frame_shift * 2):  # Fix rounding issues
        ordered[-1].end = duration
    return ordered


def _tier_name(speaker: str, annotation_type: str, multiple_speakers: bool) -> str:
    """Return the tier name, prefixed with the speaker only when there are several speakers."""
    if multiple_speakers:
        return f"{speaker} - {annotation_type}"
    return annotation_type


def export_textgrid(
    speaker_data: Dict[str, Dict[str, List[CtmInterval]]],
    output_path: Path,
    duration: float,
    frame_shift: float,
    output_format: str = TextFileType.TEXTGRID.value,
) -> None:
    """
    Export aligned file to TextGrid

    Nothing is written when no speaker has any interval. In every format only
    the final interval of each tier is extended to the file duration to absorb
    rounding; the interval objects passed in are modified by this adjustment.

    Parameters
    ----------
    speaker_data: dict[Speaker, dict[str, list[:class:`~montreal_forced_aligner.data.CtmInterval`]]
        Per speaker, per word/phone :class:`~montreal_forced_aligner.data.CtmInterval`
    output_path: :class:`~pathlib.Path`
        Output path of the file
    duration: float
        Duration of the file
    frame_shift: float
        Frame shift of features, in seconds
    output_format: str, optional
        "csv" and "json" select those writers; any other value is passed to
        :meth:`Textgrid.save` as the Praat format (e.g. "long_textgrid" or
        "short_textgrid")
    """
    duration = round(duration, 6)
    has_data = any(
        len(intervals) for data in speaker_data.values() for intervals in data.values()
    )
    if not has_data:
        return
    multiple_speakers = len(speaker_data) > 1

    if output_format == "csv":
        csv_data = []
        for speaker, data in speaker_data.items():
            for annotation_type, intervals in data.items():
                for a in _snap_final_interval(intervals, duration, frame_shift):
                    csv_data.append(
                        {
                            "Begin": a.begin,
                            "End": a.end,
                            "Label": a.label,
                            "Type": annotation_type,
                            "Speaker": speaker,
                        }
                    )
        with mfa_open(output_path, "w") as f:
            writer = csv.DictWriter(f, fieldnames=["Begin", "End", "Label", "Type", "Speaker"])
            writer.writeheader()
            writer.writerows(csv_data)
    elif output_format == "json":
        tiers = {}
        for speaker, data in speaker_data.items():
            for annotation_type, intervals in data.items():
                tier_name = _tier_name(speaker, annotation_type, multiple_speakers)
                # Tiers are listed even when they hold no intervals
                tier = tiers.setdefault(tier_name, {"type": "interval", "entries": []})
                for a in _snap_final_interval(intervals, duration, frame_shift):
                    tier["entries"].append([a.begin, a.end, a.label])
        json_data = {"start": 0, "end": duration, "tiers": tiers}
        with mfa_open(output_path, "w") as f:
            json.dump(json_data, f, indent=4, ensure_ascii=False)
    else:
        # Create initial textgrid
        tg = Textgrid()
        tg.minTimestamp = 0
        tg.maxTimestamp = duration
        for speaker, data in speaker_data.items():
            for annotation_type, intervals in data.items():
                tier_name = _tier_name(speaker, annotation_type, multiple_speakers)
                if tier_name not in tg.tierNames:
                    tg.addTier(tgio.IntervalTier(tier_name, [], minT=0, maxT=duration))
                tier = tg.getTier(tier_name)
                ordered = _snap_final_interval(intervals, duration, frame_shift)
                for i, a in enumerate(ordered):
                    tg_interval = a.to_tg_interval()
                    # Remove overlap by starting where the previous interval ended
                    if i > 0 and tier._entries[-1].end > tg_interval.start:
                        a.begin = tier._entries[-1].end
                        tg_interval = a.to_tg_interval()
                    tier._entries.append(tg_interval)
        for tier in tg.tiers:
            # Clip a final entry that still runs past the end of the file
            if len(tier._entries) > 0 and tier._entries[-1][1] > tg.maxTimestamp:
                tier.insertEntry(
                    Interval(tier._entries[-1].start, tg.maxTimestamp, tier._entries[-1].label),
                    collisionMode="replace",
                )
        tg.save(
            str(output_path),
            includeBlankSpaces=True,
            format=output_format,
            minimumIntervalLength=None,
            reportingMode="silence",
        )