# Source code for montreal_forced_aligner.dictionary.mixins

"""Mixins for dictionary parsing capabilities"""

from __future__ import annotations

import abc
import os
import re
import typing
from collections import Counter
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from montreal_forced_aligner.abc import DatabaseMixin
from montreal_forced_aligner.data import PhoneSetType, PhoneType, WordType
from montreal_forced_aligner.db import Phone, Word
from montreal_forced_aligner.helper import mfa_open

if TYPE_CHECKING:
    from montreal_forced_aligner.abc import MetaDict

DEFAULT_PUNCTUATION = list(r'、。।,?!!@<>→"”()“„–,.:;—¿?¡:)!\\&%#*~【】,…‥「」『』〝〟″⟨⟩♪・‹›«»~′$+=‘')

DEFAULT_WORD_BREAK_MARKERS = list(r'?!!(),,.:;¡¿?“„"”&~%#—…‥、。【】$+=〝〟″‹›«»・⟨⟩「」『』')

DEFAULT_QUOTE_MARKERS = list("“„\"”〝〟″「」『』‚ʻʿ‘′'")

DEFAULT_CLITIC_MARKERS = list("'’‘")
DEFAULT_COMPOUND_MARKERS = list("-/")
DEFAULT_BRACKETS = [("[", "]"), ("{", "}"), ("<", ">"), ("(", ")"), ("<", ">")]

__all__ = ["SanitizeFunction", "SplitWordsFunction", "DictionaryMixin", "TemporaryDictionaryMixin"]


class SanitizeFunction:
    """
    Class for functions that sanitize text and strip punctuation

    Parameters
    ----------
    clitic_marker: str, optional
        Clitic marker; matches of ``clitic_cleanup_regex`` are replaced with it.
        When None, both clitic steps are skipped.
    clitic_cleanup_regex: re.Pattern, optional
        Pattern whose matches are replaced with ``clitic_marker``
    clitic_quote_regex: re.Pattern, optional
        Pattern with a ``word`` named group; each match is replaced by that group.
        Only applied when ``clitic_marker`` occurs in the text.
    punctuation_regex: re.Pattern, optional
        Pattern matched against each word; matching words are dropped
    word_break_regex: re.Pattern, optional
        Pattern used to split text into words; plain whitespace splitting is used when None
    bracket_regex: re.Pattern, optional
        Pattern locating bracketed spans whose contents get sanitized
    bracket_sanitize_regex: re.Pattern, optional
        Pattern whose matches inside bracketed spans are replaced with ``_``
    ignore_case: bool
        Flag for whether all items should be converted to lower case, defaults to True
    """

    def __init__(
        self,
        clitic_marker: str,
        clitic_cleanup_regex: Optional[re.Pattern],
        clitic_quote_regex: Optional[re.Pattern],
        punctuation_regex: Optional[re.Pattern],
        word_break_regex: Optional[re.Pattern],
        bracket_regex: Optional[re.Pattern],
        bracket_sanitize_regex: Optional[re.Pattern],
        ignore_case: bool = True,
    ):
        self.clitic_marker = clitic_marker
        self.clitic_cleanup_regex = clitic_cleanup_regex
        self.clitic_quote_regex = clitic_quote_regex
        self.punctuation_regex = punctuation_regex
        self.word_break_regex = word_break_regex
        self.bracket_regex = bracket_regex
        self.bracket_sanitize_regex = bracket_sanitize_regex
        self.ignore_case = ignore_case

    def __call__(self, text: str) -> typing.Generator[str]:
        """
        Sanitize text according to punctuation, quotes, and word break characters

        Parameters
        ----------
        text: str
            Text to sanitize

        Yields
        ------
        str
            Each non-empty, non-punctuation word, in order
        """
        if self.ignore_case:
            text = text.lower()
        if self.bracket_regex is not None and self.bracket_sanitize_regex is not None:
            # finditer walks the text as it was before this loop; rebinding ``text``
            # below does not disturb the iteration
            for word_object in self.bracket_regex.finditer(text):
                word = word_object.group(0)
                new_word = self.bracket_sanitize_regex.sub("_", word)
                text = text.replace(word, new_word)
        if self.clitic_cleanup_regex is not None and self.clitic_marker is not None:
            text = self.clitic_cleanup_regex.sub(self.clitic_marker, text)
        if (
            self.clitic_quote_regex is not None
            and self.clitic_marker is not None
            and self.clitic_marker in text
        ):
            text = self.clitic_quote_regex.sub(r"\g<word>", text)
        if self.word_break_regex is not None:
            words = self.word_break_regex.split(text)
        else:
            words = text.split()
        for w in words:
            if not w:
                continue
            if self.punctuation_regex is not None and self.punctuation_regex.match(w):
                continue
            yield w
class SplitWordsFunction:
    """
    Class for functions that splits words that have compound and clitic markers

    A word is split only when splitting helps. That means at least one piece is in
    ``word_mapping``, or a clitic was stripped off. Otherwise the word is returned
    unchanged.

    Parameters
    ----------
    clitic_marker: str
        Character that marks clitics
    initial_clitic_regex: re.Pattern, optional
        Pattern matched at the start of a segment to strip leading clitics
    final_clitic_regex: re.Pattern, optional
        Pattern searched for in a segment to strip trailing clitics
    compound_regex: re.Pattern, optional
        Pattern used to split compound words into their parts
    non_speech_regexes: dict[str, re.Pattern]
        Mapping of non-speech labels to the patterns that identify them
    oov_word : str, optional
        What to label words not in the dictionary, defaults to None
    word_mapping: dict[str, int], optional
        Mapping of words to integer IDs; an empty mapping is treated as None
    grapheme_mapping: dict[str, int], optional
        Mapping of graphemes to integer IDs; an empty mapping is treated as None
    """

    # Stray quote characters stripped from the edges of pieces after clitic splitting
    _initial_quote_regex = re.compile("^'")
    _final_quote_regex = re.compile("'$")

    def __init__(
        self,
        clitic_marker: str,
        initial_clitic_regex: Optional[re.Pattern],
        final_clitic_regex: Optional[re.Pattern],
        compound_regex: Optional[re.Pattern],
        non_speech_regexes: Dict[str, re.Pattern],
        oov_word: Optional[str] = None,
        word_mapping: Optional[Dict[str, int]] = None,
        grapheme_mapping: Optional[Dict[str, int]] = None,
    ):
        self.clitic_marker = clitic_marker
        self.compound_regex = compound_regex
        self.oov_word = oov_word
        # Symbols that ``to_str`` collapses to ``oov_word``
        self.specials_set = {self.oov_word, "<s>", "</s>"}
        # Empty mappings are normalized to None so "is None" means "no mapping"
        self.word_mapping = word_mapping or None
        self.grapheme_mapping = grapheme_mapping or None
        self.compound_pattern = None
        self.clitic_pattern = None
        self.non_speech_regexes = non_speech_regexes
        self.initial_clitic_regex = initial_clitic_regex
        self.final_clitic_regex = final_clitic_regex
        self.has_initial = initial_clitic_regex is not None
        self.has_final = final_clitic_regex is not None

    def to_str(self, normalized_text: str) -> str:
        """
        Map a normalized word to its lookup string

        Words in ``specials_set`` (``oov_word``, ``<s>``, ``</s>``) become ``oov_word``.
        Words matching one of ``non_speech_regexes`` become that pattern's label.
        Anything else is returned unchanged.

        Parameters
        ----------
        normalized_text: str
            Word to convert

        Returns
        -------
        str
            Lookup string for the word
        """
        if normalized_text in self.specials_set:
            return self.oov_word
        for word, regex in self.non_speech_regexes.items():
            if regex.match(normalized_text):
                return word
        return normalized_text

    def split_clitics(
        self,
        item: str,
    ) -> List[str]:
        """
        Split a word into subwords based on dictionary information

        Parameters
        ----------
        item: str
            Word to split

        Returns
        -------
        list[str]
            List of subwords, or ``[item]`` when splitting does not help
        """
        # Without a word mapping there is no way to tell whether splitting helps
        if self.word_mapping is None:
            return [item]
        if self.compound_regex is not None:
            segments = self.compound_regex.split(item)
        else:
            segments = [item]
        marker = self.clitic_marker
        split = []
        benefit = False
        for seg in segments:
            if not seg:
                continue
            if not marker or marker not in seg:
                split.append(seg)
                if seg in self.word_mapping:
                    benefit = True
                continue
            # A lone leading (or else trailing) marker around a known word is dropped
            if seg.startswith(marker):
                if seg[len(marker) :] in self.word_mapping:
                    split.append(seg[len(marker) :])
                    benefit = True
                    continue
            elif seg.endswith(marker):
                if seg[: -len(marker)] in self.word_mapping:
                    split.append(seg[: -len(marker)])
                    benefit = True
                    continue
            initial_clitics = []
            if self.has_initial:
                while True:
                    clitic = self.initial_clitic_regex.match(seg)
                    # An empty match would leave ``seg`` unchanged and loop forever
                    if clitic is None or clitic.end(0) == 0:
                        break
                    benefit = True
                    initial_clitics.append(clitic.group(0))
                    seg = seg[clitic.end(0) :]
                    if seg in self.word_mapping:
                        break
            final_clitics = []
            if self.has_final:
                while True:
                    clitic = self.final_clitic_regex.search(seg)
                    # An empty match would record "" as a clitic and never finish
                    if clitic is None or clitic.start(0) == clitic.end(0):
                        break
                    benefit = True
                    final_clitics.append(clitic.group(0))
                    seg = seg[: clitic.start(0)]
                    if seg in self.word_mapping:
                        break
            for clitic_text in initial_clitics:
                cleaned = self._initial_quote_regex.sub("", clitic_text)
                if cleaned:
                    split.append(cleaned)
            seg = self._final_quote_regex.sub("", self._initial_quote_regex.sub("", seg))
            if seg:
                split.append(seg)
            # Final clitics were collected right-to-left; emit them in reading order
            for clitic_text in reversed(final_clitics):
                cleaned = self._final_quote_regex.sub("", clitic_text)
                if cleaned:
                    split.append(cleaned)
            if seg in self.word_mapping:
                benefit = True
        if not benefit:
            return [item]
        return split

    def parse_graphemes(
        self,
        item: str,
    ) -> typing.Generator[str]:
        """
        Yield the graphemes of a word

        If the word matches one of ``non_speech_regexes``, only that pattern's label
        is yielded. Otherwise each character is yielded, and characters missing from
        ``grapheme_mapping`` (or every character, when there is no mapping) are
        replaced by ``oov_word``.

        Parameters
        ----------
        item: str
            Word to break into graphemes

        Yields
        ------
        str
            Grapheme, non-speech label, or ``oov_word``
        """
        for word, regex in self.non_speech_regexes.items():
            if regex.match(item):
                yield word
                return
        for c in item:
            if self.grapheme_mapping is not None and c in self.grapheme_mapping:
                yield c
            else:
                yield self.oov_word

    def __call__(
        self,
        item: str,
    ) -> List[str]:
        """
        Return the list of sub words if necessary taking into account clitic and compound markers

        Parameters
        ----------
        item: str
            Word to look up

        Returns
        -------
        list[str]
            ``[item]`` for known or non-speech words, otherwise the result of
            :meth:`split_clitics`
        """
        if self.word_mapping is not None and item in self.word_mapping:
            return [item]
        for regex in self.non_speech_regexes.values():
            if regex.match(item):
                return [item]
        return self.split_clitics(item)
class DictionaryMixin:
    """
    Mixin holding dictionary text-parsing options and phone-set information for
    MFA classes that use acoustic models

    Parameters
    ----------
    oov_word : str
        What to label words not in the dictionary, defaults to ``'<unk>'``
    silence_word : str
        Word label for silence, defaults to ``'<eps>'``
    optional_silence_phone : str
        Phone used for optional silence, defaults to ``'sil'``
    oov_phone : str
        Phone used for out-of-vocabulary words, defaults to ``'spn'``
    other_noise_phone : str, optional
        Extra phone counted among the silence phones when set, defaults to None
    position_dependent_phones : bool
        Specifies whether phones should be represented as dependent on their
        position in the word (beginning, middle or end), defaults to False
    num_silence_states : int
        Number of states to use for silence phones, defaults to 5
    num_non_silence_states : int
        Number of states to use for non-silence phones, defaults to 3
    shared_silence_phones : bool
        Specify whether to share states across all silence phones, defaults to False
    ignore_case: bool
        Flag for whether all items should be converted to lower case, defaults to True
    silence_probability : float
        Probability of optional silences following words, defaults to 0.5
    initial_silence_probability : float
        Probability of initial silence, defaults to 0.5
    final_silence_correction : float, optional
        Correction term on final silence, defaults to None
    final_non_silence_correction : float, optional
        Correction term on final non-silence, defaults to None
    punctuation: list[str], optional
        Punctuation to use when parsing text, defaults to ``DEFAULT_PUNCTUATION``
    clitic_markers: list[str], optional
        Clitic markers to use when parsing text, defaults to ``DEFAULT_CLITIC_MARKERS``.
        The first marker is also stored as ``clitic_marker``.
    compound_markers: list[str], optional
        Compound markers to use when parsing text, defaults to ``DEFAULT_COMPOUND_MARKERS``
    quote_markers: list[str], optional
        Quotation markers to use when parsing text, defaults to ``DEFAULT_QUOTE_MARKERS``
    word_break_markers: list[str], optional
        Word break markers to use when parsing text, defaults to ``DEFAULT_WORD_BREAK_MARKERS``
    brackets: list[tuple[str, str]], optional
        Character tuples to treat as full brackets around words, defaults to ``DEFAULT_BRACKETS``
    non_silence_phones: set[str], optional
        Non-silence phones of the dictionary, defaults to an empty set
    disambiguation_symbols: set[str], optional
        Set of disambiguation symbols, defaults to an empty set
    clitic_set: set[str], optional
        Set of clitic words, defaults to an empty set
    max_disambiguation_symbol: int
        Maximum number of disambiguation symbols required, defaults to 0
    phone_set_type: str or PhoneSetType
        Phone set type, or the name of a ``PhoneSetType`` member, defaults to ``"UNKNOWN"``
    preserve_suprasegmentals: bool
        Flag for whether to keep phones separated by tone and stress, defaults to False
    base_phone_mapping: dict[str, str], optional
        Mapping between phone symbols to make them share a base root for decision trees
    use_cutoff_model: bool
        Flag stored as ``use_cutoff_model``, defaults to False
    """

    # Word-position suffixes (begin, end, internal, singleton) used for positional phones
    positions: List[str] = ["_B", "_E", "_I", "_S"]

    def __init__(
        self,
        oov_word: str = "<unk>",
        silence_word: str = "<eps>",
        optional_silence_phone: str = "sil",
        oov_phone: str = "spn",
        other_noise_phone: Optional[str] = None,
        position_dependent_phones: bool = False,
        num_silence_states: int = 5,
        num_non_silence_states: int = 3,
        shared_silence_phones: bool = False,
        ignore_case: bool = True,
        silence_probability: float = 0.5,
        initial_silence_probability: float = 0.5,
        final_silence_correction: Optional[float] = None,
        final_non_silence_correction: Optional[float] = None,
        punctuation: Optional[List[str]] = None,
        clitic_markers: Optional[List[str]] = None,
        compound_markers: Optional[List[str]] = None,
        quote_markers: Optional[List[str]] = None,
        word_break_markers: Optional[List[str]] = None,
        brackets: Optional[List[Tuple[str, str]]] = None,
        non_silence_phones: Optional[Set[str]] = None,
        disambiguation_symbols: Optional[Set[str]] = None,
        clitic_set: Optional[Set[str]] = None,
        max_disambiguation_symbol: int = 0,
        phone_set_type: typing.Union[str, PhoneSetType] = "UNKNOWN",
        preserve_suprasegmentals: bool = False,
        base_phone_mapping: Optional[Dict[str, str]] = None,
        use_cutoff_model: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Copy the module-level defaults so per-instance edits never leak into them
        self.punctuation = list(DEFAULT_PUNCTUATION) if punctuation is None else punctuation
        self.clitic_markers = (
            list(DEFAULT_CLITIC_MARKERS) if clitic_markers is None else clitic_markers
        )
        # First clitic marker, or None when no clitic markers are configured
        self.clitic_marker = self.clitic_markers[0] if self.clitic_markers else None
        self.compound_markers = (
            list(DEFAULT_COMPOUND_MARKERS) if compound_markers is None else compound_markers
        )
        self.brackets = list(DEFAULT_BRACKETS) if brackets is None else brackets
        self.quote_markers = list(DEFAULT_QUOTE_MARKERS) if quote_markers is None else quote_markers
        self.word_break_markers = (
            list(DEFAULT_WORD_BREAK_MARKERS) if word_break_markers is None else word_break_markers
        )
        self.num_silence_states = num_silence_states
        self.num_non_silence_states = num_non_silence_states
        self.shared_silence_phones = shared_silence_phones
        self.silence_probability = silence_probability
        self.initial_silence_probability = initial_silence_probability
        self.final_silence_correction = final_silence_correction
        self.final_non_silence_correction = final_non_silence_correction
        self.ignore_case = ignore_case
        self.oov_word = oov_word
        self.silence_word = silence_word
        self.bracketed_word = "[bracketed]"
        self.cutoff_word = "<cutoff>"
        self.laughter_word = "[laughter]"
        self.position_dependent_phones = position_dependent_phones
        self.optional_silence_phone = optional_silence_phone
        self.other_noise_phone = other_noise_phone
        self.oov_phone = oov_phone
        self.oovs_found = Counter()
        self.non_silence_phones = set() if non_silence_phones is None else non_silence_phones
        self.excluded_phones = set()
        self.excluded_pronunciation_count = 0
        self.max_disambiguation_symbol = max_disambiguation_symbol
        self.disambiguation_symbols = (
            set() if disambiguation_symbols is None else disambiguation_symbols
        )
        self.clitic_set = set() if clitic_set is None else clitic_set
        if phone_set_type is None:
            phone_set_type = "UNKNOWN"
        if not isinstance(phone_set_type, PhoneSetType):
            phone_set_type = PhoneSetType[phone_set_type]
        self.phone_set_type = phone_set_type
        self.preserve_suprasegmentals = preserve_suprasegmentals
        self.base_phone_mapping = base_phone_mapping
        self.punctuation_regex = None
        self.compound_regex = None
        self.bracket_regex = None
        self.laughter_regex = None
        self.word_break_regex = None
        self.bracket_sanitize_regex = None
        self.use_cutoff_model = use_cutoff_model
        # Lazily filled cache for ``phone_groups``
        self._phone_groups = {}

    @property
    def base_phones(self) -> Dict[str, Set[str]]:
        """Non-silence phones grouped by their base phone"""
        base_phones = {}
        for p in self.non_silence_phones:
            base_phones.setdefault(self.get_base_phone(p), set()).add(p)
        return base_phones

    def get_base_phone(self, phone: str) -> str:
        """
        Get the base phone, either through stripping diacritics, tone, and/or stress

        With ``preserve_suprasegmentals`` set, ARPA and PINYIN phones are returned
        unchanged. Other phone sets then strip with the phone set's suprasegmental
        regex instead of its base-phone regex. ``base_phone_mapping`` is applied
        to the stripped result.

        Parameters
        ----------
        phone: str
            Phone used in pronunciation dictionary

        Returns
        -------
        str
            Base phone
        """
        if self.preserve_suprasegmentals and (
            self.phone_set_type is PhoneSetType.ARPA or self.phone_set_type is PhoneSetType.PINYIN
        ):
            return phone
        if self.preserve_suprasegmentals:
            pattern = self.phone_set_type.suprasegmental_phone_regex
        else:
            pattern = self.phone_set_type.base_phone_regex
        if not self.phone_set_type.has_base_phone_regex:
            return phone
        base_phone = pattern.sub("", phone)
        if self.base_phone_mapping and base_phone in self.base_phone_mapping:
            return self.base_phone_mapping[base_phone]
        return base_phone

    @property
    def extra_questions_mapping(self) -> Dict[str, List[str]]:
        """
        Mapping of extra questions for the given phone set type

        Always contains ``silence_question``. The phone set's ``extra_questions``
        are added as follows:

        - ARPA: each question is used as given.
        - IPA: questions are restricted to dictionary phones whose base phone
          matches.
        - PINYIN: questions are restricted to dictionary phones whose base phone
          or full symbol matches.

        IPA and PINYIN questions with fewer than two phones are dropped. With
        positional phones, per-position ``non_silence*``/``silence*`` questions
        are added as well.
        """
        mapping = {"silence_question": []}
        for p in sorted(self.silence_phones):
            mapping["silence_question"].append(p)
            if self.position_dependent_phones:
                mapping["silence_question"].extend([p + x for x in self.positions])
        for k, v in self.phone_set_type.extra_questions.items():
            if k not in mapping:
                mapping[k] = []
            if self.phone_set_type is PhoneSetType.ARPA:
                if self.position_dependent_phones:
                    for x in sorted(v):
                        mapping[k].extend([x + pos for pos in self.positions])
                else:
                    mapping[k] = sorted(v)
            elif (
                self.phone_set_type is PhoneSetType.IPA
                or self.phone_set_type is PhoneSetType.PINYIN
            ):
                # PINYIN questions may also list full (toned) phones, not just base phones
                match_full_phone = self.phone_set_type is PhoneSetType.PINYIN
                filtered_v = {
                    x
                    for x in self.non_silence_phones
                    if self.get_base_phone(x) in v or (match_full_phone and x in v)
                }
                if len(filtered_v) < 2:
                    del mapping[k]
                    continue
                if self.position_dependent_phones:
                    for x in sorted(filtered_v):
                        mapping[k].extend([x + pos for pos in self.positions])
                else:
                    mapping[k] = sorted(filtered_v)
        if self.position_dependent_phones:
            phones = sorted(self.non_silence_phones)
            for pos in self.positions:
                mapping[f"non_silence{pos}"] = [x + pos for x in phones]
            silence_phones = sorted(self.silence_phones)
            for pos in [""] + self.positions:
                mapping[f"silence{pos}"] = [x + pos for x in silence_phones]
        return mapping

    @property
    def dictionary_options(self) -> MetaDict:
        """Dictionary options"""
        return {
            "punctuation": self.punctuation,
            "clitic_markers": self.clitic_markers,
            "clitic_set": self.clitic_set,
            "compound_markers": self.compound_markers,
            "brackets": self.brackets,
            "num_silence_states": self.num_silence_states,
            "num_non_silence_states": self.num_non_silence_states,
            "shared_silence_phones": self.shared_silence_phones,
            "silence_probability": self.silence_probability,
            "initial_silence_probability": self.initial_silence_probability,
            "final_silence_correction": self.final_silence_correction,
            "final_non_silence_correction": self.final_non_silence_correction,
            "oov_word": self.oov_word,
            "silence_word": self.silence_word,
            "position_dependent_phones": self.position_dependent_phones,
            "optional_silence_phone": self.optional_silence_phone,
            "oov_phone": self.oov_phone,
            "non_silence_phones": self.non_silence_phones,
            "max_disambiguation_symbol": self.max_disambiguation_symbol,
            "disambiguation_symbols": self.disambiguation_symbols,
            "phone_set_type": str(self.phone_set_type),
        }

    @property
    def silence_phones(self) -> Set[str]:
        """Silence phones: optional silence, OOV phone and, if set, the other-noise phone"""
        if self.other_noise_phone is not None:
            return {self.optional_silence_phone, self.oov_phone, self.other_noise_phone}
        return {
            self.optional_silence_phone,
            self.oov_phone,
        }

    @property
    def context_independent_csl(self) -> str:
        """Context independent colon-separated list"""
        return ":".join(str(self.phone_mapping[x]) for x in self.kaldi_silence_phones)

    @property
    def specials_set(self) -> Set[str]:
        """Special words, like the ``oov_word`` ``silence_word``, ``<s>``, and ``</s>``"""
        return {
            self.silence_word,
            self.oov_word,
            self.bracketed_word,
            self.laughter_word,
            "<s>",
            "</s>",
        }

    @property
    def phone_mapping(self) -> Dict[str, int]:
        """
        Mapping of phones to integer IDs

        IDs are assigned in this order: ``<eps>`` (0), Kaldi silence phones,
        Kaldi non-silence phones, then ``#0`` .. ``#{max_disambiguation_symbol + 1}``.
        Side effect: those ``#`` symbols are added to ``disambiguation_symbols``.
        """
        phone_mapping = {"<eps>": 0}
        i = 0
        for p in self.kaldi_silence_phones:
            i += 1
            phone_mapping[p] = i
        for p in self.kaldi_non_silence_phones:
            i += 1
            phone_mapping[p] = i
        i = max(phone_mapping.values())
        for x in range(self.max_disambiguation_symbol + 2):
            p = f"#{x}"
            self.disambiguation_symbols.add(p)
            i += 1
            phone_mapping[p] = i
        return phone_mapping

    @property
    def silence_disambiguation_symbol(self) -> str:
        """
        Silence disambiguation symbol
        """
        return f"#{self.max_disambiguation_symbol + 1}"

    @property
    def reversed_phone_mapping(self) -> Dict[int, str]:
        """
        A mapping of integer ids to phones
        """
        return {v: k for k, v in self.phone_mapping.items()}

    @property
    def positional_silence_phones(self) -> List[str]:
        """
        List of silence phones with positions; each phone is followed by its positional variants
        """
        silence_phones = []
        for p in sorted(self.silence_phones):
            silence_phones.append(p)
            silence_phones.extend(p + pos for pos in self.positions)
        return silence_phones

    def _generate_positional_list(self, phones: Set[str]) -> List[str]:
        """
        Helper function to generate positional list for phones along with any base phones for the phone set

        Parameters
        ----------
        phones: set[str]
            Set of phones

        Returns
        -------
        list[str]
            List of positional phones, sorted by base phone
        """
        # NOTE(review): ``|=`` updates the caller's set in place. When called with
        # ``self.non_silence_phones``, this permanently adds base phones to it. Later
        # lookups (e.g. grouped phones and topology) appear to rely on that, so the
        # in-place update is kept deliberately.
        phones |= {self.get_base_phone(p) for p in phones}
        # dict.fromkeys drops duplicates while keeping first-seen order
        return list(
            dict.fromkeys(
                p + pos
                for p in sorted(phones)
                if p in self.non_silence_phones
                for pos in self.positions
            )
        )

    def _generate_non_positional_list(self, phones: Set[str]) -> List[str]:
        """
        Helper function to generate non-positional list for phones with any base phones for the phone set

        Parameters
        ----------
        phones: set[str]
            Set of phones

        Returns
        -------
        list[str]
            List of non-positional phones, sorted by base phone
        """
        return sorted(phones | {self.get_base_phone(p) for p in phones})

    def _generate_phone_list(self, phones: Set[str]) -> List[str]:
        """
        Helper function to generate phone list

        Parameters
        ----------
        phones: set[str]
            Set of phones

        Returns
        -------
        list[str]
            List of positional or non-positional phones, sorted by base phone
        """
        if self.position_dependent_phones:
            return self._generate_positional_list(phones)
        return self._generate_non_positional_list(phones)

    @property
    def positional_non_silence_phones(self) -> List[str]:
        """
        List of non-silence phones with positions
        """
        return self._generate_positional_list(self.non_silence_phones)

    @property
    def kaldi_non_silence_phones(self) -> List[str]:
        """Non silence phones in Kaldi format"""
        if self.position_dependent_phones:
            return self.positional_non_silence_phones
        return self._generate_non_positional_list(self.non_silence_phones)

    @property
    def phone_groups(self) -> typing.Dict[str, typing.List[str]]:
        """
        Non-silence phones grouped by base phone, with the base phone listed first

        Computed on first access and cached in ``_phone_groups``.
        """
        if not self._phone_groups:
            for p in sorted(self.non_silence_phones):
                base_phone = self.get_base_phone(p)
                group = self._phone_groups.setdefault(base_phone, [base_phone])
                if p not in group:
                    group.append(p)
        return self._phone_groups

    @property
    def kaldi_grouped_phones(self) -> Dict[str, List[str]]:
        """Phone groups in Kaldi format (with positional variants when enabled); empty groups dropped"""
        groups = {}
        for k, v in self.phone_groups.items():
            if self.position_dependent_phones:
                groups[k] = [x + pos for pos in self.positions for x in v]
            else:
                groups[k] = v
        return {k: v for k, v in groups.items() if v}

    @property
    def kaldi_silence_phones(self) -> List[str]:
        """Silence phones in Kaldi format"""
        if self.position_dependent_phones:
            return self.positional_silence_phones
        return sorted(self.silence_phones)

    @property
    def optional_silence_csl(self) -> str:
        """
        Phone ID of the optional silence phone, or an empty string if it has no ID
        """
        try:
            return str(self.phone_mapping[self.optional_silence_phone])
        except KeyError:
            return ""

    @property
    def silence_csl(self) -> str:
        """
        A colon-separated string of silence phone ids
        """
        phone_mapping = self.phone_mapping
        return ":".join(str(phone_mapping[x]) for x in self.kaldi_silence_phones)

    @property
    def non_silence_csl(self) -> str:
        """
        A colon-separated string of non-silence phone ids
        """
        phone_mapping = self.phone_mapping
        return ":".join(str(phone_mapping[x]) for x in self.kaldi_non_silence_phones)

    @property
    def phones(self) -> set:
        """
        The set of all phones (silence and non-silence)
        """
        return self.silence_phones | self.non_silence_phones

    def check_bracketed(self, word: str) -> bool:
        """
        Checks whether a given string is surrounded by brackets.

        Parameters
        ----------
        word : str
            Text to check for final brackets

        Returns
        -------
        bool
            True if the word is fully bracketed, false otherwise
        """
        return any(
            re.match(rf"^{re.escape(b[0])}.*{re.escape(b[1])}$", word) for b in self.brackets
        )
class TemporaryDictionaryMixin(DictionaryMixin, DatabaseMixin, metaclass=abc.ABCMeta):
    """
    Mixin for dictionaries backed by a temporary directory

    Writes the phone-level files under ``<dictionary_output_directory>/phones``:
    word boundaries, topology, phone sets/roots and disambiguation symbols.

    Notes
    -----
    ``dictionary_output_directory`` (used via ``joinpath``) and ``session()`` are
    not defined here. They are presumably provided by another class in the MRO
    (e.g. ``DatabaseMixin``) — TODO confirm.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Lazily computed / cached values
        self._disambiguation_symbols_int_path = None
        self._phones_dir = None
        self._lexicon_fst_paths = {}
        self._num_words = None
        self._num_speech_words = None

    @property
    def num_words(self) -> int:
        """Number of words (including OOVs and special symbols) in the dictionary; cached after the first query"""
        if self._num_words is None:
            with self.session() as session:
                self._num_words = session.query(Word).count()
        return self._num_words

    @property
    def num_speech_words(self) -> int:
        """Number of speech and clitic words in the dictionary; cached after the first query"""
        if self._num_speech_words is None:
            with self.session() as session:
                self._num_speech_words = (
                    session.query(Word)
                    .filter(Word.word_type.in_([WordType.speech, WordType.clitic]))
                    .count()
                )
        return self._num_speech_words

    @property
    def word_boundary_int_path(self) -> Path:
        """Path to the word boundary integer IDs"""
        return self.dictionary_output_directory.joinpath("phones", "word_boundary.int")

    def _write_word_boundaries(self) -> None:
        """
        Write the word boundaries file to the temporary directory

        Each phone is labelled by its position suffix: ``_B`` begin, ``_S``
        singleton, ``_I`` internal, ``_E`` end. Any other phone is labelled
        ``nonword``. ``<eps>`` and ``#`` disambiguation entries are skipped.
        Both files are always created, but stay empty when phones are not
        position dependent.
        """
        boundary_path = self.dictionary_output_directory.joinpath("phones", "word_boundary.txt")
        suffix_categories = {"_B": "begin", "_S": "singleton", "_I": "internal", "_E": "end"}
        with mfa_open(boundary_path, "w") as f, mfa_open(self.word_boundary_int_path, "w") as intf:
            if not self.position_dependent_phones:
                return
            # Compute the (relatively expensive) mapping once instead of per phone
            phone_mapping = self.phone_mapping
            for p in sorted(phone_mapping, key=phone_mapping.get):
                if p == "<eps>" or p.startswith("#"):
                    continue
                cat = suffix_categories.get(p[-2:], "nonword")
                f.write(" ".join([p, cat]) + "\n")
                intf.write(" ".join([str(phone_mapping[p]), cat]) + "\n")

    def _get_grouped_phones(self) -> Dict[str, Set[str]]:
        """
        Group non-silence phones by manner for use in Kaldi processing

        Each phone is classified by its base phone (see ``get_base_phone``)
        against the category sets of ``phone_set_type``. Categories are tried in
        the order below; unmatched phones go to ``"other"``.

        Returns
        -------
        dict[str, set[str]]
            Grouped phone by manner
        """
        pst = self.phone_set_type
        # NOTE(review): no branch below ever fills "nasals"; such phones land in the
        # first later category that matches, else "other". Kept to preserve output.
        phones = {
            "stops": set(),
            "fricatives": set(),
            "affricates": set(),
            "liquids": set(),
            "nasals": set(),
            "monophthongs": set(),
            "diphthongs": set(),
            "triphthongs": set(),
            "other": set(),
        }
        for p in self.non_silence_phones:
            base_phone = self.get_base_phone(p)
            if base_phone in pst.stops:
                phones["stops"].add(p)
            elif base_phone in pst.affricates:
                phones["affricates"].add(p)
            elif base_phone in (pst.laterals | pst.approximants | pst.nasal_approximants):
                phones["liquids"].add(p)
            elif base_phone in (pst.fricatives | pst.lateral_fricatives | pst.sibilants):
                phones["fricatives"].add(p)
            elif base_phone in pst.vowels:
                phones["monophthongs"].add(p)
            elif base_phone in pst.diphthong_phones:
                phones["diphthongs"].add(p)
            elif base_phone in pst.triphthong_phones:
                phones["triphthongs"].add(p)
            else:
                phones["other"].add(p)
        return phones

    def _write_topo(self) -> None:
        """
        Write the topo file to the temporary directory

        One entry covers all silence phones and has ``num_silence_states``
        emitting states. Each non-empty manner group from ``_get_grouped_phones``
        gets its own entry with ``num_non_silence_states`` emitting states.
        """
        phone_mapping = self.phone_mapping
        num_sil = self.num_silence_states
        # With a single silence state there is nothing to spread probability over;
        # avoid dividing by zero (the lone state takes the final-state branch below)
        sil_transp = 1 / (num_sil - 1) if num_sil > 1 else 0.0
        silence_lines = [
            "<TopologyEntry>",
            "<ForPhones>",
            " ".join(str(phone_mapping[x]) for x in self.kaldi_silence_phones),
            "</ForPhones>",
        ]
        for i in range(num_sil):
            if i == num_sil - 1:
                # Final emitting state: self-loop or move to the non-emitting end state
                silence_lines.append(
                    f"<State> {i} <PdfClass> {i} <Transition> {i} 0.75 <Transition> {num_sil} 0.25 </State>"
                )
            elif i == 0:
                # Initial silence state
                transition_string = " ".join(
                    f"<Transition> {x} {sil_transp}" for x in range(num_sil - 1)
                )
                silence_lines.append(f"<State> {i} <PdfClass> {i} {transition_string} </State>")
            else:
                # non-final silence states
                transition_string = " ".join(
                    f"<Transition> {x} {sil_transp}" for x in range(1, num_sil)
                )
                silence_lines.append(f"<State> {i} <PdfClass> {i} {transition_string} </State>")
        silence_lines.append(f"<State> {num_sil} </State>")
        silence_lines.append("</TopologyEntry>")
        topo_sections = ["\n".join(silence_lines)]

        num_states = self.num_non_silence_states
        for phone_list in self._get_grouped_phones().values():
            if not phone_list:
                continue
            non_silence_lines = [
                "<TopologyEntry>",
                "<ForPhones>",
                " ".join(str(phone_mapping[x]) for x in self._generate_phone_list(phone_list)),
                "</ForPhones>",
            ]
            for i in range(num_states):
                if i == 0:
                    # Initial non_silence state: equal probability to every later state
                    transition_probability = 1 / num_states
                    transition_string = " ".join(
                        f"<Transition> {x} {transition_probability}"
                        for x in range(1, num_states + 1)
                    )
                    non_silence_lines.append(
                        f"<State> {i} <PdfClass> {i} {transition_string} </State>"
                    )
                elif i == num_states - 1:
                    non_silence_lines.append(
                        f"<State> {i} <PdfClass> {i} <Transition> {i+1} 1.0 </State>"
                    )
                else:
                    non_silence_lines.append(
                        f"<State> {i} <PdfClass> {i} <Transition> {i} 0.5 <Transition> {i+1} 0.5 </State>"
                    )
            non_silence_lines.append(f"<State> {num_states} </State>")
            non_silence_lines.append("</TopologyEntry>")
            topo_sections.append("\n".join(non_silence_lines))

        with mfa_open(self.topo_path, "w") as f:
            f.write("<Topology>\n")
            for section in topo_sections:
                f.write(section + "\n\n")
            f.write("</Topology>\n")

    def _write_phone_sets(self) -> None:
        """
        Write phone symbol sets to the temporary directory

        Writes ``sets.txt``/``sets.int`` and ``roots.txt``/``roots.int`` under
        ``phones/``. Silence phones come first, either as one
        ``not-shared not-split`` line or one ``shared split`` line per silence
        phone (in sorted order, so output is reproducible). Then comes one
        ``shared split`` line per group of ``kaldi_grouped_phones``.
        """
        phones_dir = self.dictionary_output_directory.joinpath("phones")
        sets_file = phones_dir.joinpath("sets.txt")
        roots_file = phones_dir.joinpath("roots.txt")
        sets_int_file = phones_dir.joinpath("sets.int")
        roots_int_file = phones_dir.joinpath("roots.int")
        phone_mapping = self.phone_mapping

        with mfa_open(sets_file, "w") as setf, mfa_open(roots_file, "w") as rootf, mfa_open(
            sets_int_file, "w"
        ) as setintf, mfa_open(roots_int_file, "w") as rootintf:

            def write_line(phones: List[str], root_type: str) -> None:
                """Write one sets line and its roots line, in text and integer form"""
                phone_string = " ".join(phones)
                phone_int_string = " ".join(str(phone_mapping[x]) for x in phones)
                setf.write(f"{phone_string}\n")
                setintf.write(f"{phone_int_string}\n")
                rootf.write(f"{root_type} {phone_string}\n")
                rootintf.write(f"{root_type} {phone_int_string}\n")

            # process silence phones
            if self.shared_silence_phones:
                write_line(self.kaldi_silence_phones, "not-shared not-split")
            else:
                for sp in sorted(self.silence_phones):
                    if self.position_dependent_phones:
                        mapped = [sp + x for x in [""] + self.positions]
                    else:
                        mapped = [sp]
                    write_line(mapped, "shared split")

            # process nonsilence phones
            for group in self.kaldi_grouped_phones.values():
                write_line(sorted(group, key=lambda x: phone_mapping[x]), "shared split")

    @property
    def phone_symbol_table_path(self) -> Path:
        """Path to file containing phone symbols and their integer IDs"""
        return self.phones_dir.joinpath("phones.txt")

    @property
    def grapheme_symbol_table_path(self) -> Path:
        """Path to file containing grapheme symbols and their integer IDs"""
        return self.phones_dir.joinpath("graphemes.txt")

    @property
    def disambiguation_symbols_txt_path(self) -> Path:
        """Path to the file containing phone disambiguation symbols"""
        return self.phones_dir.joinpath("disambiguation_symbols.txt")

    @property
    def disambiguation_symbols_int_path(self) -> Path:
        """Path to the file containing integer IDs for phone disambiguation symbols"""
        if self._disambiguation_symbols_int_path is None:
            self._disambiguation_symbols_int_path = self.phones_dir.joinpath(
                "disambiguation_symbols.int"
            )
        return self._disambiguation_symbols_int_path

    @property
    def phones_dir(self) -> Path:
        """Directory for storing phone information"""
        if self._phones_dir is None:
            self._phones_dir = self.dictionary_output_directory.joinpath("phones")
        return self._phones_dir

    @property
    def topo_path(self) -> Path:
        """Path to the dictionary's topology file"""
        return self.phones_dir.joinpath("topo")

    def _write_disambig(self) -> None:
        """
        Write disambiguation symbols to the temporary directory

        Labels and IDs of disambiguation phones are read from the database.
        ``phone_disambig.txt`` receives the ID of ``#0`` (no trailing newline).
        """
        disambig = self.disambiguation_symbols_txt_path
        disambig_int = self.disambiguation_symbols_int_path
        with self.session() as session, mfa_open(disambig, "w") as outf, mfa_open(
            disambig_int, "w"
        ) as intf:
            disambiguation_symbols = session.query(Phone.mapping_id, Phone.kaldi_label).filter(
                Phone.phone_type == PhoneType.disambiguation
            )
            for p_id, p in disambiguation_symbols:
                outf.write(f"{p}\n")
                intf.write(f"{p_id}\n")
        phone_disambig_path = self.phones_dir.joinpath("phone_disambig.txt")
        with mfa_open(phone_disambig_path, "w") as f:
            f.write(str(self.phone_mapping["#0"]))