Source code for montreal_forced_aligner.vad.multiprocessing

"""Multiprocessing functionality for VAD"""
from __future__ import annotations

import logging
import os
import re
import subprocess
import typing
from pathlib import Path
from typing import TYPE_CHECKING, List, Union

import librosa
import numpy as np
import pynini
import pywrapfst
import sqlalchemy
from Bio import pairwise2

from montreal_forced_aligner.abc import KaldiFunction
from montreal_forced_aligner.corpus.features import online_feature_proc
from montreal_forced_aligner.data import CtmInterval, MfaArguments, WordType
from montreal_forced_aligner.db import Dictionary, File, SoundFile, Speaker, Utterance, Word
from montreal_forced_aligner.exceptions import KaldiProcessingError
from montreal_forced_aligner.helper import mfa_open
from montreal_forced_aligner.utils import parse_ctm_output, read_feats, thirdparty_binary

try:
    import warnings

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        torch_logger = logging.getLogger("speechbrain.utils.torch_audio_backend")
        torch_logger.setLevel(logging.ERROR)
        torch_logger = logging.getLogger("speechbrain.utils.train_logger")
        torch_logger.setLevel(logging.ERROR)
        import torch
        from speechbrain.pretrained import VAD

    FOUND_SPEECHBRAIN = True
except (ImportError, OSError):
    FOUND_SPEECHBRAIN = False
    VAD = None

if TYPE_CHECKING:
    SpeakerCharacterType = Union[str, int]
    from dataclasses import dataclass

    from montreal_forced_aligner.abc import MetaDict
else:
    from dataclassy import dataclass


@dataclass
class SegmentVadArguments(MfaArguments):
    """
    Arguments for :class:`~montreal_forced_aligner.segmenter.SegmentVadFunction`

    Parameters
    ----------
    vad_path: :class:`~pathlib.Path`
        Path that the function hands to ``copy-vector`` as an ``scp:`` input
    segmentation_options: MetaDict
        Segmentation settings; the function reads the ``frame_shift``,
        ``close_th``, ``large_chunk_size`` and ``len_th`` keys
    """

    vad_path: Path
    segmentation_options: MetaDict
def get_initial_segmentation(
    frames: List[Union[int, str]], frame_shift: float
) -> List[CtmInterval]:
    """
    Compute initial segmentation over voice activity

    Every maximal run of frames whose integer value is greater than zero
    becomes one interval labelled ``"speech"``.

    Parameters
    ----------
    frames: list[Union[int, str]]
        List of frames with VAD output; each entry is converted with ``int``
    frame_shift: float
        Frame shift of features in seconds

    Returns
    -------
    List[CtmInterval]
        Initial segmentation, in the order the runs occur in ``frames``
    """
    segments = []
    cur_segment = None
    for i, f in enumerate(frames):
        if int(f) > 0:
            if cur_segment is None:
                # First voiced frame of a run opens a segment; the end is filled in on close
                cur_segment = CtmInterval(begin=i * frame_shift, end=0, label="speech")
        elif cur_segment is not None:
            # NOTE(review): a run closed by a silent frame ends at (i - 1) * frame_shift,
            # while a run reaching the last frame ends at len(frames) * frame_shift below.
            # Kept as-is to preserve existing timings; confirm the asymmetry is intended.
            cur_segment.end = (i - 1) * frame_shift
            segments.append(cur_segment)
            cur_segment = None
    if cur_segment is not None:
        # Speech ran through the final frame, so close the segment at the end of the input
        cur_segment.end = len(frames) * frame_shift
        segments.append(cur_segment)
    return segments
def merge_segments(
    segments: List[CtmInterval],
    min_pause_duration: float,
    max_segment_length: float,
    min_segment_length: float,
) -> List[CtmInterval]:
    """
    Merge segments together

    Consecutive segments are fused unless the pause between them exceeds
    ``min_pause_duration`` or the fused result would be longer than
    ``max_segment_length``.  The input intervals are modified in place.

    Parameters
    ----------
    segments: List[CtmInterval]
        Initial segments
    min_pause_duration: float
        Minimum amount of silence time to mark an utterance boundary
    max_segment_length: float
        Maximum length of segments before they're broken up
    min_segment_length: float
        Minimum length of segments returned

    Returns
    -------
    List[CtmInterval]
        Merged segments
    """
    snap_limit = min_pause_duration / 2
    kept = []
    for segment in segments:
        previous = kept[-1] if kept else None
        opens_new_segment = (
            previous is None
            or segment.begin > previous.end + min_pause_duration
            or segment.end - previous.begin > max_segment_length
        )
        if not opens_new_segment:
            # Close enough to the previous segment: extend it instead of starting another
            previous.end = segment.end
            continue
        if not (segment.end - segment.begin > min_pause_duration):
            # Too short to stand on its own as a new segment, so it is dropped
            continue
        if previous is not None and snap_limit:
            # Pull the two boundaries toward each other, by at most a quarter of the pause
            gap = segment.begin - previous.end
            shift = (gap if gap < snap_limit else snap_limit) / 2
            previous.end += shift
            segment.begin -= shift
        kept.append(segment)
    return [segment for segment in kept if segment.end - segment.begin > min_segment_length]
def _add_interjection_paths(
    fst: pynini.Fst,
    word_symbol_table: pywrapfst.SymbolTable,
    interjection_words: typing.List[str],
    source_state: int,
    return_state: int,
) -> None:
    """Add one optional detour per interjection from ``source_state`` to ``return_state``."""
    eps = word_symbol_table.find("<eps>")
    weight_type = fst.weight_type()
    for interjection in interjection_words:
        # Entering an interjection carries a weight of 10; every other arc on the detour is free
        previous_state = fst.add_state()
        fst.add_arc(
            source_state,
            pywrapfst.Arc(eps, eps, pywrapfst.Weight(weight_type, 10), previous_state),
        )
        # An interjection containing spaces becomes a chain with one arc per word
        parts = interjection.split() if " " in interjection else [interjection]
        for part in parts:
            following_state = fst.add_state()
            label = word_symbol_table.find(part)
            fst.add_arc(
                previous_state,
                pywrapfst.Arc(label, label, pywrapfst.Weight.one(weight_type), following_state),
            )
            previous_state = following_state
        # Epsilon arc back so the path can continue from where it left off
        fst.add_arc(
            previous_state,
            pywrapfst.Arc(eps, eps, pywrapfst.Weight.one(weight_type), return_state),
        )


def construct_utterance_segmentation_fst(
    text: str,
    word_symbol_table: pywrapfst.SymbolTable,
    interjection_words: typing.List[str] = None,
):
    """
    Build an acceptor for an utterance's text that can be entered at any word

    State ``i`` is reached after the first ``i`` words.  Besides the linear
    path, the start state has an arc into every word after the first, every
    word state is final, and interjections may be inserted at the start and
    after each word.

    Parameters
    ----------
    text: str
        Whitespace-separated transcript
    word_symbol_table: pywrapfst.SymbolTable
        Symbol table used to look up the labels of words and of ``<eps>``
    interjection_words: list[str], optional
        Interjections that may be inserted; entries containing spaces are
        treated as word sequences

    Returns
    -------
    pynini.Fst
        FST after determinize, rmepsilon, disambiguate and determinize
    """
    if interjection_words is None:
        interjection_words = []
    words = text.split()
    fst = pynini.Fst()
    start_state = fst.add_state()
    fst.set_start(start_state)
    # Reserve the word states up front so interjection states are numbered after them
    fst.add_states(len(words))
    weight_type = fst.weight_type()

    # For an empty transcript the start state doubles as the final state
    # (previously ``next_state`` was unbound in that case and the function raised)
    next_state = start_state
    for i, w in enumerate(words):
        next_state = i + 1
        label = word_symbol_table.find(w)
        if i != 0:
            # Allow a segment to begin at any word, not only the first
            fst.add_arc(
                start_state,
                pywrapfst.Arc(label, label, pywrapfst.Weight.one(weight_type), next_state),
            )
        fst.add_arc(
            i, pywrapfst.Arc(label, label, pywrapfst.Weight.one(weight_type), next_state)
        )
        # Allow a segment to stop after any word, with a final weight of 1
        fst.set_final(next_state, pywrapfst.Weight(weight_type, 1))
        _add_interjection_paths(fst, word_symbol_table, interjection_words, next_state, next_state)
    _add_interjection_paths(fst, word_symbol_table, interjection_words, start_state, start_state)
    # Overrides the final weight set in the loop for the last word's state
    fst.set_final(next_state, pywrapfst.Weight.one(weight_type))
    fst = pynini.determinize(fst)
    fst = pynini.rmepsilon(fst)
    fst = pynini.disambiguate(fst)
    fst = pynini.determinize(fst)
    return fst


def segment_utterance(
    session: sqlalchemy.orm.Session,
    working_directory: Path,
    utterance_id: int,
    vad_model: VAD,
    segmentation_options: MetaDict,
    mfcc_options: MetaDict,
    pitch_options: MetaDict,
    lda_options: MetaDict,
    decode_options: MetaDict,
):
    """
    Split one utterance into VAD-derived sub-utterances and assign text to each

    The utterance is segmented with the VAD model, the segments are decoded
    against an FST built from the utterance text, and the decoded words are
    aligned back to the original text with :func:`align_text`.

    Parameters
    ----------
    session: sqlalchemy.orm.Session
        Database session used to load the utterance and its dictionary's words
    working_directory: :class:`~pathlib.Path`
        Directory where the scp files, FSTs and log are written; ``final.alimdl``
        and ``tree`` are read from it
    utterance_id: int
        ID of the utterance to split
    vad_model: VAD
        Model passed to :func:`segment_utterance_vad_speech_brain`
    segmentation_options: MetaDict
        Options passed to :func:`segment_utterance_vad_speech_brain`
    mfcc_options: MetaDict
        Options passed to ``online_feature_proc``
    pitch_options: MetaDict
        Options passed to ``online_feature_proc``
    lda_options: MetaDict
        Options passed to ``online_feature_proc``
    decode_options: MetaDict
        Decoding options; reads ``transition_scale``, ``self_loop_scale``,
        ``acoustic_scale``, ``beam``, ``max_active``, ``lattice_beam`` and,
        when ``boost_silence`` is not 1.0, ``optional_silence_csl``

    Returns
    -------
    tuple[int, dict[int, dict]]
        The utterance ID and, per segment index, a dict with ``file_id``,
        ``begin``, ``end``, ``channel``, ``speaker_id``, ``text``, ``oovs`` and
        ``normalized_text``; segments that received no text are left out

    Raises
    ------
    KaldiProcessingError
        If ``compile-graph`` exits with a non-zero status
    """
    log_path = working_directory.joinpath("log", "utterance_segmentation.log")
    utterance, speaker, dictionary, sound_file = (
        session.query(Utterance, Speaker, Dictionary, SoundFile)
        .join(Utterance.speaker)
        .join(Speaker.dictionary)
        .join(Utterance.file)
        .join(File.sound_file)
        .filter(Utterance.id == utterance_id)
        .first()
    )
    text = utterance.normalized_text
    if not text:
        text = utterance.text
    oovs = utterance.oovs.split()
    normalized_text = " ".join(
        [x if x not in oovs else dictionary.oov_word for x in text.split()]
    )
    # Single-column queries yield 1-tuples; unpack them so the FST builder and
    # align_text receive plain strings rather than row objects
    interjection_words = [
        word
        for (word,) in session.query(Word.word)
        .filter(Word.dictionary_id == dictionary.id)
        .filter(Word.word_type == WordType.interjection)
    ]
    # Maps integer word IDs in the decoder output back to word strings
    reversed_word_mapping = {
        m_id: w
        for w, m_id in session.query(Word.word, Word.mapping_id).filter(
            Word.dictionary_id == dictionary.id
        )
    }

    segments = segment_utterance_vad_speech_brain(
        utterance, sound_file, vad_model, segmentation_options
    )

    word_symbol_table = pywrapfst.SymbolTable.read_text(dictionary.words_symbol_path)
    utterance_fst_path = working_directory.joinpath("utterance.fst")
    utterance_fst = construct_utterance_segmentation_fst(
        normalized_text,
        word_symbol_table,
        interjection_words=interjection_words,
    )
    utterance_fst.write(utterance_fst_path)

    wav_path = working_directory.joinpath("wav.scp")
    segment_path = working_directory.joinpath("segments.scp")
    utt2spk_path = working_directory.joinpath("utt2spk.scp")
    cmvn_path = working_directory.joinpath("cmvn.scp")
    trans_path = working_directory.joinpath("trans.scp")
    with mfa_open(wav_path, "w") as f:
        f.write(f"{utterance.file_id} {sound_file.sox_string}\n")
    if speaker.cmvn:
        with mfa_open(cmvn_path, "w") as f:
            f.write(f"{utterance.speaker_id} {speaker.cmvn}\n")
    if speaker.fmllr:
        with mfa_open(trans_path, "w") as f:
            f.write(f"{utterance.speaker_id} {speaker.fmllr}\n")

    sub_utterance_information = {}
    with mfa_open(segment_path, "w") as f, mfa_open(utt2spk_path, "w") as utt2spk_f:
        for i, (begin, end) in enumerate(segments):
            # Start each segment 50 ms early, clamped at zero
            begin = max(begin - 0.05, 0)
            f.write(
                f"{utterance.speaker_id}-{i} {utterance.file_id} {begin} {end} {utterance.channel}\n"
            )
            utt2spk_f.write(f"{utterance.speaker_id}-{i} {utterance.speaker_id}\n")
            sub_utterance_information[i] = {
                "file_id": utterance.file_id,
                "begin": float(begin),
                "end": float(end),
                "channel": utterance.channel,
                "speaker_id": utterance.speaker_id,
            }

    model_path = working_directory.joinpath("final.alimdl")
    tree_path = working_directory.joinpath("tree")
    hclg_path = working_directory.joinpath("hclg.fst")
    # Also write a text rendering of the FST with word symbols attached
    with open(working_directory.joinpath("utterance.text_fst"), "w", encoding="utf8") as f:
        utterance_fst.set_input_symbols(word_symbol_table)
        utterance_fst.set_output_symbols(word_symbol_table)
        f.write(str(utterance_fst))

    with mfa_open(log_path, "w") as log_file:
        proc = subprocess.Popen(
            [
                thirdparty_binary("compile-graph"),
                f"--read-disambig-syms={dictionary.disambiguation_symbols_int_path}",
                f"--transition-scale={decode_options['transition_scale']}",
                f"--self-loop-scale={decode_options['self_loop_scale']}",
                tree_path,
                model_path,
                dictionary.lexicon_disambig_fst_path,
                utterance_fst_path,
                f"{hclg_path}",
            ],
            stderr=log_file,
            env=os.environ,
        )
        proc.wait()
        if proc.returncode != 0:
            raise KaldiProcessingError([log_path])

        feature_proc = online_feature_proc(
            working_directory,
            wav_path,
            segment_path,
            mfcc_options,
            pitch_options,
            lda_options,
            log_file,
        )
        if decode_options.get("boost_silence", 1.0) != 1.0:
            # Trailing "|" makes this a piped model input rather than a file path
            mdl_string = f"gmm-boost-silence --boost={decode_options['boost_silence']} {decode_options['optional_silence_csl']} {model_path} - |"
        else:
            mdl_string = model_path

        # Pipeline: features -> lattices -> word-aligned lattices -> CTM with confidences
        latgen_proc = subprocess.Popen(
            [
                thirdparty_binary("gmm-latgen-faster"),
                f"--acoustic-scale={decode_options['acoustic_scale']}",
                f"--beam={decode_options['beam']}",
                f"--max-active={decode_options['max_active']}",
                f"--lattice-beam={decode_options['lattice_beam']}",
                f"--word-symbol-table={dictionary.words_symbol_path}",
                "--allow-partial=true",
                mdl_string,
                hclg_path,
                "ark,s,cs:-",
                "ark:-",
            ],
            stderr=log_file,
            stdin=feature_proc.stdout,
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        lat_align_proc = subprocess.Popen(
            [
                thirdparty_binary("lattice-align-words-lexicon"),
                dictionary.align_lexicon_int_path,
                mdl_string,
                "ark,s,cs:-",
                "ark:-",
            ],
            stderr=log_file,
            stdin=latgen_proc.stdout,
            stdout=subprocess.PIPE,
            env=os.environ,
        )
        ctm_proc = subprocess.Popen(
            [
                thirdparty_binary("lattice-to-ctm-conf"),
                f"--acoustic-scale={decode_options['acoustic_scale']}",
                "ark,s,cs:-",
                "-",
            ],
            stderr=log_file,
            stdin=lat_align_proc.stdout,
            stdout=subprocess.PIPE,
            env=os.environ,
            encoding="utf8",
        )
        split_texts = {}
        for sub_id, intervals in parse_ctm_output(ctm_proc, reversed_word_mapping):
            # Only words decoded with full confidence are kept
            split_text = " ".join([x.label for x in intervals if x.confidence == 1.0])
            if not split_text:
                del sub_utterance_information[sub_id]
                continue
            split_texts[sub_id] = split_text
        ctm_proc.wait()

    split_texts = align_text(split_texts, text, oovs, dictionary.oov_word, interjection_words)
    for i, split_text in split_texts.items():
        split_tokens = split_text.split()
        split_oovs = {x for x in oovs if x in split_tokens}
        sub_utterance_information[i]["text"] = split_text
        sub_utterance_information[i]["oovs"] = " ".join(split_oovs)
        sub_utterance_information[i]["normalized_text"] = split_text
    # Segments that never received text are not returned
    sub_utterance_information = {
        k: v for k, v in sub_utterance_information.items() if "text" in v
    }
    return utterance_id, sub_utterance_information


def align_text(split_utterance_texts, text, oovs, oov_word, interjection_words):
    """
    Redistribute the original text over decoded sub-utterance texts

    The concatenated decoded words are globally aligned to the original words;
    the aligned sequence is then cut back into chunks following the decoded
    chunk sizes, with original words that the decoder missed kept in place.

    Parameters
    ----------
    split_utterance_texts: dict
        Decoded text per sub-utterance key, in segment order
    text: str
        Original whitespace-separated text of the whole utterance
    oovs: collection of str
        Out-of-vocabulary words of the utterance
    oov_word: str
        Symbol that stands in for OOV words in the decoded text
    interjection_words: collection of str
        Interjections the decoder may have inserted

    Returns
    -------
    dict
        Re-aligned text per sub-utterance key
    """
    text = text.split()
    chunk_tokens = [t.split() for t in split_utterance_texts.values()]
    lengths = [len(t) for t in chunk_tokens]
    split_utterance_text = [token for t in chunk_tokens for token in t]

    def score_func(first_element, second_element):
        if first_element == second_element:
            return 0
        if first_element == oov_word:
            # The OOV symbol may only stand in for a word that really is OOV
            return 0 if second_element in oovs else -10
        if first_element in interjection_words:
            return -10
        return -2

    alignments = pairwise2.align.globalcs(
        split_utterance_text, text, score_func, -0.5, -0.1, gap_char=["-"], one_alignment_only=True
    )
    results = [[]]
    split_ind = 0
    current_size = 0
    last_split = len(lengths) - 1
    for a in alignments:
        for i, sa in enumerate(a.seqA):
            sb = a.seqB[i]
            # Restore the original word where the decoder emitted the OOV symbol
            # (uses ``oov_word`` rather than a hard-coded "<unk>")
            if sa == oov_word:
                sa = sb
            if sa != "-":
                if (
                    split_ind < last_split
                    and sa not in chunk_tokens[split_ind]
                    and chunk_tokens[split_ind + 1][0] == sa
                ):
                    # Word belongs to the next chunk: move on early
                    results.append([])
                    split_ind += 1
                    current_size = 0
                results[-1].append(sa)
                current_size += 1
                if split_ind < last_split and current_size >= lengths[split_ind]:
                    # Current chunk has all of its decoded words
                    results.append([])
                    split_ind += 1
                    current_size = 0
            elif sb != "-":
                # Original word the decoder skipped: keep it in the current chunk
                results[-1].append(sb)
    results = {k: " ".join(r) for k, r in zip(split_utterance_texts, results)}
    return results


def segment_utterance_vad_speech_brain(
    utterance: Utterance, sound_file: SoundFile, vad_model: VAD, segmentation_options: MetaDict
) -> np.ndarray:
    """
    Find speech segments within an utterance using a SpeechBrain VAD model

    Parameters
    ----------
    utterance: Utterance
        Utterance to segment; ``begin``, ``duration`` and ``channel`` are read
    sound_file: SoundFile
        Sound file whose ``sound_file_path`` is loaded
    vad_model: VAD
        Model used for every VAD step
    segmentation_options: MetaDict
        Reads ``activation_th``, ``deactivation_th``, ``en_activation_th``,
        ``en_deactivation_th``, ``close_th``, ``len_th``, ``double_check`` and
        ``speech_th``; ``apply_energy_VAD`` is optional and defaults to True

    Returns
    -------
    numpy.ndarray
        Array with one ``(begin, end)`` row per segment, in seconds, offset by
        ``utterance.begin``
    """
    y, _ = librosa.load(
        sound_file.sound_file_path,
        sr=16000,
        mono=False,
        offset=utterance.begin,
        duration=utterance.duration,
    )
    if len(y.shape) > 1:
        # Multi-channel audio is loaded as (channels, samples): pick the channel row
        y = y[utterance.channel, :]
    prob_chunks = vad_model.get_speech_prob_chunk(
        torch.tensor(y[np.newaxis, :], device=vad_model.device)
    ).cpu()
    prob_th = vad_model.apply_threshold(
        prob_chunks,
        activation_th=segmentation_options["activation_th"],
        deactivation_th=segmentation_options["deactivation_th"],
    ).float()
    # Compute the boundaries of the speech segments
    boundaries = vad_model.get_boundaries(prob_th, output_value="seconds")
    # Boundaries are relative to the loaded excerpt; shift them by the utterance start
    boundaries += utterance.begin
    # Apply energy-based VAD on the detected speech segments.  The option used to
    # be short-circuited by ``True or ...``; it is honoured now and a missing key
    # keeps the previous always-on behaviour.
    if segmentation_options.get("apply_energy_VAD", True):
        boundaries = vad_model.energy_VAD(
            sound_file.sound_file_path,
            boundaries,
            activation_th=segmentation_options["en_activation_th"],
            deactivation_th=segmentation_options["en_deactivation_th"],
        )

    # Merge short segments
    boundaries = vad_model.merge_close_segments(
        boundaries, close_th=segmentation_options["close_th"]
    )

    # Remove short segments
    boundaries = vad_model.remove_short_segments(boundaries, len_th=segmentation_options["len_th"])

    # Double check speech segments
    if segmentation_options["double_check"]:
        boundaries = vad_model.double_check_speech_segments(
            boundaries, sound_file.sound_file_path, speech_th=segmentation_options["speech_th"]
        )
    # Pad both edges of every segment by a third of the merge threshold
    padding = round(segmentation_options["close_th"] / 3, 3)
    boundaries[:, 0] -= padding
    boundaries[:, 1] += padding
    return boundaries.numpy()
class SegmentVadFunction(KaldiFunction):
    """
    Multiprocessing function to generate segments from VAD output.

    See Also
    --------
    :meth:`montreal_forced_aligner.segmenter.Segmenter.segment_vad`
        Main function that calls this function in parallel
    :meth:`montreal_forced_aligner.segmenter.Segmenter.segment_vad_arguments`
        Job method for generating arguments for this function
    :kaldi_utils:`segmentation.pl`
        Kaldi utility

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.segmenter.SegmentVadArguments`
        Arguments for the function
    """

    progress_pattern = re.compile(
        r"^LOG.*processed (?P<done>\d+) utterances.*(?P<no_feats>\d+) had.*(?P<unvoiced>\d+) were.*"
    )

    def __init__(self, args: SegmentVadArguments):
        super().__init__(args)
        self.vad_path = args.vad_path
        self.segmentation_options = args.segmentation_options

    def _run(self) -> typing.Generator[typing.Tuple[int, List[CtmInterval]], None, None]:
        """
        Run the function

        Yields
        ------
        tuple
            The utterance identifier produced by ``read_feats`` and the list of
            merged speech intervals for that utterance
        """
        options = self.segmentation_options
        with mfa_open(self.log_path, "w") as log_file:
            # Dump the VAD vectors as text so they can be read frame by frame
            copy_proc = subprocess.Popen(
                [
                    thirdparty_binary("copy-vector"),
                    "--binary=false",
                    f"scp:{self.vad_path}",
                    "ark,t:-",
                ],
                stdout=subprocess.PIPE,
                stderr=log_file,
                env=os.environ,
            )
            for utt_id, frames in read_feats(copy_proc):
                initial_segments = get_initial_segmentation(frames, options["frame_shift"])
                merged = merge_segments(
                    initial_segments,
                    options["close_th"],  # minimum pause that separates segments
                    options["large_chunk_size"],  # maximum merged segment length
                    options["len_th"],  # minimum length of a returned segment
                )
                yield utt_id, merged