Source code for aligner.corpus

import os
import subprocess
import sys
import traceback
import shutil
import struct
import wave
import logging
from collections import defaultdict, Counter
from textgrid import TextGrid, IntervalTier

from .helper import thirdparty_binary, load_text, make_safe
from .multiprocessing import mfcc

from .exceptions import SampleRateError, CorpusError

from .dictionary import sanitize

from .config import MfccConfig


def output_mapping(mapping, path):
    with open(path, 'w', encoding='utf8') as f:
        for k in sorted(mapping.keys()):
            v = mapping[k]
            if isinstance(v, list):
                v = ' '.join(v)
            f.write('{} {}\n'.format(k, v))


def save_scp(scp, path, sort=True, multiline=False):
    with open(path, 'w', encoding='utf8') as f:
        if sort:
            scp = sorted(scp)
        for line in scp:
            if multiline:
                f.write('{}\n{}\n'.format(make_safe(line[0]), make_safe(line[1])))
            else:
                f.write('{}\n'.format(' '.join(map(make_safe, line))))

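# Illustrative usage sketch for save_scp (hypothetical data and path); each entry
# is a [key, value] pair written as a single "key value" line, or as two lines
# per entry when multiline=True:
#
#     scp = [['utt2', '/data/utt2.wav'], ['utt1', '/data/utt1.wav']]
#     save_scp(scp, '/tmp/wav.scp')  # entries are sorted, so utt1 is written first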

def save_groups(groups, seg_dir, pattern, multiline=False):
    for i, g in enumerate(groups):
        path = os.path.join(seg_dir, pattern.format(i))
        save_scp(g, path, multiline=multiline)


def load_scp(path):
    '''
    Load a Kaldi script file (.scp)

    See http://kaldi-asr.org/doc/io.html#io_sec_scp_details for more information

    Parameters
    ----------
    path : str
        Path to Kaldi script file

    Returns
    -------
    dict
        Dictionary where the keys are the first token on each line and the values
        are the remaining tokens (a single string if there is one, otherwise a list)

    '''
    scp = {}
    with open(path, 'r', encoding='utf8') as f:
        for line in f:
            line = line.strip()
            if line == '':
                continue
            line_list = line.split()
            key = line_list.pop(0)
            if len(line_list) == 1:
                value = line_list[0]
            else:
                value = line_list
            scp[key] = value
    return scp

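# Illustrative usage sketch for load_scp (hypothetical file). Given a wav.scp
# containing:
#
#     utt1 /data/utt1.wav
#     utt2 /data/utt2.wav
#
# load_scp('/tmp/wav.scp') returns {'utt1': '/data/utt1.wav', 'utt2': '/data/utt2.wav'};
# when a line has more than two columns, the value is a list of the remaining columns.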

def find_lab(filename, files):
    '''
    Finds a .lab or .txt file that corresponds to a wav file.  The .lab extension is given priority.

    Parameters
    ----------
    filename : str
        Name of wav file
    files : list
        List of files to search in

    Returns
    -------
    str or None
        If a corresponding .lab or .txt file is found, returns it, otherwise returns None
    '''
    name, ext = os.path.splitext(filename)
    for f in files:
        fn, fext = os.path.splitext(f)
        if fn == name and fext.lower() == '.lab':
            return f
    for f in files: # Use .txt if no .lab file available
        fn, fext = os.path.splitext(f)
        if fn == name and fext.lower() == '.txt':
            return f
    return None

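# Illustrative usage sketch for find_lab (hypothetical file names):
#
#     files = ['s1_a.wav', 's1_a.lab', 's1_b.wav', 's1_b.txt']
#     find_lab('s1_a.wav', files)  # -> 's1_a.lab'
#     find_lab('s1_b.wav', files)  # -> 's1_b.txt' (used only when no .lab exists)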

def find_textgrid(filename, files):
    '''
    Finds a TextGrid file that corresponds to a wav file

    Parameters
    ----------
    filename : str
        Name of wav file
    files : list
        List of files to search in

    Returns
    -------
    str or None
        If a corresponding TextGrid is found, returns it, otherwise returns None
    '''
    name, ext = os.path.splitext(filename)
    for f in files:
        fn, fext = os.path.splitext(f)
        if fn == name and fext.lower() == '.textgrid':
            return f
    return None

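# Illustrative usage sketch for find_textgrid (hypothetical file names):
#
#     files = ['dialogue.wav', 'dialogue.TextGrid']
#     find_textgrid('dialogue.wav', files)  # -> 'dialogue.TextGrid'
#     find_textgrid('other.wav', files)     # -> None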

def get_n_channels(file_path):
    '''
    Return the number of channels for a sound file

    Parameters
    ----------
    file_path : str
        Path to a wav file

    Returns
    -------
    int
        Number of channels (1 if mono, 2 if stereo)
    '''

    with wave.open(file_path, 'rb') as soundf:
        n_channels = soundf.getnchannels()
    return n_channels


def get_sample_rate(file_path):
    with wave.open(file_path, 'rb') as soundf:
        sr = soundf.getframerate()
    return sr


def get_wav_duration(file_path):
    with wave.open(file_path, 'rb') as soundf:
        sr = soundf.getframerate()
        nframes = soundf.getnframes()
    return nframes / sr

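# Illustrative usage sketch for the wave helpers above (hypothetical path):
#
#     path = '/data/speaker1/utt1.wav'
#     get_n_channels(path)    # e.g. 1 for mono, 2 for stereo
#     get_sample_rate(path)   # e.g. 16000
#     get_wav_duration(path)  # duration in seconds (number of frames / sample rate)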

def extract_temp_channels(wav_path, temp_directory):
    '''
    Extract the two channels of a stereo wav file into separate mono wav files

    Parameters
    ----------
    wav_path : str
        Path to stereo wav file
    temp_directory : str
        Directory in which to save the extracted channel files

    Returns
    -------
    tuple(str, str)
        Paths to the extracted first (A) and second (B) channel wav files
    '''
    name, ext = os.path.splitext(wav_path)
    base = os.path.basename(name)
    A_path = os.path.join(temp_directory, base + '_A.wav')
    B_path = os.path.join(temp_directory, base + '_B.wav')
    samp_step = 1000000
    if not os.path.exists(A_path):
        with wave.open(wav_path, 'rb') as inf, \
                wave.open(A_path, 'wb') as af, \
                wave.open(B_path, 'wb') as bf:
            chans = inf.getnchannels()
            samps = inf.getnframes()
            samplerate = inf.getframerate()
            sampwidth = inf.getsampwidth()
            assert sampwidth == 2
            af.setnchannels(1)
            af.setframerate(samplerate)
            af.setsampwidth(sampwidth)
            bf.setnchannels(1)
            bf.setframerate(samplerate)
            bf.setsampwidth(sampwidth)
            cur_samp = 0
            while cur_samp < samps:
                s = inf.readframes(samp_step)
                cur_samp += samp_step
                act = samp_step
                if cur_samp > samps:
                    act -= (cur_samp - samps)

                unpstr = '<{0}h'.format(act * chans)  # little-endian 16-bit samples
                x = list(struct.unpack(unpstr, s))  # convert the byte string into a list of ints
                values = [struct.pack('h', d) for d in x[0::chans]]
                value_str = b''.join(values)
                af.writeframes(value_str)
                values = [struct.pack('h', d) for d in x[1::chans]]
                value_str = b''.join(values)
                bf.writeframes(value_str)
    return A_path, B_path

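# Illustrative usage sketch for extract_temp_channels (hypothetical paths); the
# temp directory is assumed to already exist and the input is assumed to be a
# 16-bit stereo wav file:
#
#     a_path, b_path = extract_temp_channels('/data/dialogue.wav', '/tmp/mfa_temp')
#     # a_path -> /tmp/mfa_temp/dialogue_A.wav (first channel)
#     # b_path -> /tmp/mfa_temp/dialogue_B.wav (second channel)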

class Corpus(object):
    '''
    Class that stores information about the dataset to align.

    Corpus objects have a number of mappings from either utterances or
    speakers to various properties, and mappings between utterances and
    speakers.

    See http://kaldi-asr.org/doc/data_prep.html for more information about
    the files that are created by this class.

    Parameters
    ----------
    directory : str
        Directory of the dataset to align
    output_directory : str
        Directory to store generated data for the Kaldi binaries
    speaker_characters : int, optional
        Number of characters in the file names to count as the speaker ID;
        if not specified, speaker IDs are generated from directory names
    num_jobs : int, optional
        Number of processes to use, defaults to 3
    debug : bool, optional
        Flag for debug mode, defaults to False
    ignore_exceptions : bool, optional
        If True, duplicate wav file names are renamed rather than raising an
        error, defaults to False

    Raises
    ------
    CorpusError
        Raised if the specified corpus directory does not exist
    SampleRateError
        Raised if the wav files in the dataset do not share a consistent sample rate
    '''

    def __init__(self, directory, output_directory,
                 use_speaker_information=True,
                 speaker_characters=0,
                 num_jobs=3, debug=False,
                 ignore_exceptions=False):
        self.debug = debug
        log_dir = os.path.join(output_directory, 'logging')
        os.makedirs(log_dir, exist_ok=True)
        self.log_file = os.path.join(log_dir, 'corpus.log')
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)
        handler = logging.FileHandler(self.log_file, 'w', 'utf-8')
        handler.setFormatter(logging.Formatter('%(name)s %(message)s'))
        root_logger.addHandler(handler)
        if not os.path.exists(directory):
            raise CorpusError('The directory \'{}\' does not exist.'.format(directory))
        if not os.path.isdir(directory):
            raise CorpusError('The specified path for the corpus ({}) is not a directory.'.format(directory))
        if num_jobs < 1:
            num_jobs = 1

        print('Setting up corpus information...')
        root_logger.info('Setting up corpus information...')
        self.directory = directory
        self.output_directory = os.path.join(output_directory, 'train')
        self.temp_directory = os.path.join(self.output_directory, 'temp')
        os.makedirs(self.temp_directory, exist_ok=True)
        self.num_jobs = num_jobs

        # Set up mapping dictionaries
        self.speak_utt_mapping = defaultdict(list)
        self.utt_speak_mapping = {}
        self.utt_wav_mapping = {}
        self.text_mapping = {}
        self.word_counts = Counter()
        self.segments = {}
        self.feat_mapping = {}
        self.cmvn_mapping = {}
        self.ignored_utterances = []
        self.wav_files = []
        self.wav_durations = {}
        feat_path = os.path.join(self.output_directory, 'feats.scp')
        if os.path.exists(feat_path):
            self.feat_mapping = load_scp(feat_path)
        if speaker_characters == 0:
            self.speaker_directories = True
        else:
            self.speaker_directories = False
        self.sample_rates = defaultdict(set)
        no_transcription_files = []
        decode_error_files = []
        unsupported_sample_rate = []
        ignored_duplicates = False
        textgrid_read_errors = {}
        for root, dirs, files in os.walk(self.directory, followlinks=True):
            for f in sorted(files):
                file_name, ext = os.path.splitext(f)
                if ext.lower() != '.wav':
                    continue
                lab_name = find_lab(f, files)
                wav_path = os.path.join(root, f)
                sr = get_sample_rate(wav_path)
                if sr < 16000:
                    unsupported_sample_rate.append(wav_path)
                    continue
                if lab_name is not None:
                    utt_name = file_name
                    if utt_name in self.utt_wav_mapping:
                        if not ignore_exceptions:
                            prev_wav = self.utt_wav_mapping[utt_name]
                            raise CorpusError('Files with the same file name are not permitted. '
                                              'Files with the same name are: {}, {}.'.format(prev_wav, wav_path))
                        else:
                            ignored_duplicates = True
                            ind = 0
                            fixed_utt_name = utt_name
                            while fixed_utt_name in self.utt_wav_mapping:  # find an unused utterance name
                                ind += 1
                                fixed_utt_name = utt_name + '_{}'.format(ind)
                            utt_name = fixed_utt_name
                    if self.feat_mapping and utt_name not in self.feat_mapping:
                        self.ignored_utterances.append(utt_name)
                        continue
                    lab_path = os.path.join(root, lab_name)
                    try:
                        text = load_text(lab_path)
                    except UnicodeDecodeError:
                        decode_error_files.append(lab_path)
                        continue
                    words = [sanitize(x) for x in text.split()]
                    words = [x for x in words if x not in ['', '-', "'"]]
                    if not words:
                        continue
                    self.word_counts.update(words)
                    self.text_mapping[utt_name] = ' '.join(words)
                    if self.speaker_directories:
                        speaker_name = os.path.basename(root)
                    else:
                        if isinstance(speaker_characters, int):
                            speaker_name = f[:speaker_characters]
                        elif speaker_characters == 'prosodylab':
                            speaker_name = f.split('_')[1]
                    speaker_name = speaker_name.strip().replace(' ', '_')
                    utt_name = utt_name.strip().replace(' ', '_')
                    self.speak_utt_mapping[speaker_name].append(utt_name)
                    self.utt_wav_mapping[utt_name] = wav_path
                    self.sample_rates[get_sample_rate(wav_path)].add(speaker_name)
                    self.utt_speak_mapping[utt_name] = speaker_name
                else:
                    tg_name = find_textgrid(f, files)
                    if tg_name is None:
                        no_transcription_files.append(wav_path)
                        continue
                    self.wav_files.append(file_name)
                    self.wav_durations[file_name] = get_wav_duration(wav_path)
                    tg_path = os.path.join(root, tg_name)
                    tg = TextGrid()
                    try:
                        tg.read(tg_path)
                    except Exception as e:
                        exc_type, exc_value, exc_traceback = sys.exc_info()
                        textgrid_read_errors[tg_path] = traceback.format_exception(exc_type, exc_value, exc_traceback)
                    n_channels = get_n_channels(wav_path)
                    num_tiers = len(tg.tiers)
                    if n_channels == 2:
                        A_name = file_name + "_A"
                        B_name = file_name + "_B"
                        A_path, B_path = extract_temp_channels(wav_path, self.temp_directory)
                    elif n_channels > 2:
                        raise Exception('More than two channels')
                    if not self.speaker_directories:
                        if isinstance(speaker_characters, int):
                            speaker_name = f[:speaker_characters]
                        elif speaker_characters == 'prosodylab':
                            speaker_name = f.split('_')[1]
                        speaker_name = speaker_name.strip().replace(' ', '_')
                    for i, ti in enumerate(tg.tiers):
                        if ti.name.lower() == 'notes':
                            continue
                        if not isinstance(ti, IntervalTier):
                            continue
                        if self.speaker_directories:
                            speaker_name = ti.name.strip().replace(' ', '_')
                        self.sample_rates[get_sample_rate(wav_path)].add(speaker_name)
                        for interval in ti:
                            label = interval.mark.lower().strip()
                            label = sanitize(label)
                            words = [sanitize(x) for x in label.split()]
                            words = [x for x in words if x not in ['', '-', "'"]]
                            if not words:
                                continue
                            begin, end = round(interval.minTime, 4), round(interval.maxTime, 4)
                            utt_name = '{}_{}_{}_{}'.format(speaker_name, file_name, begin, end)
                            utt_name = utt_name.strip().replace(' ', '_').replace('.', '_')
                            if n_channels == 1:
                                if self.feat_mapping and utt_name not in self.feat_mapping:
                                    self.ignored_utterances.append(utt_name)
                                    continue
                                self.segments[utt_name] = '{} {} {}'.format(file_name, begin, end)
                                self.utt_wav_mapping[file_name] = wav_path
                            else:
                                if i < num_tiers / 2:
                                    utt_name += '_A'
                                    if self.feat_mapping and utt_name not in self.feat_mapping:
                                        self.ignored_utterances.append(utt_name)
                                        continue
                                    self.segments[utt_name] = '{} {} {}'.format(A_name, begin, end)
                                    self.utt_wav_mapping[A_name] = A_path
                                else:
                                    utt_name += '_B'
                                    if self.feat_mapping and utt_name not in self.feat_mapping:
                                        self.ignored_utterances.append(utt_name)
                                        continue
                                    self.segments[utt_name] = '{} {} {}'.format(B_name, begin, end)
                                    self.utt_wav_mapping[B_name] = B_path
                            self.text_mapping[utt_name] = ' '.join(words)
                            self.word_counts.update(words)
                            self.utt_speak_mapping[utt_name] = speaker_name
                            self.speak_utt_mapping[speaker_name].append(utt_name)
        if ignored_duplicates:
            print('At least one duplicate wav file name was found and treated as a different utterance.')
        if len(self.ignored_utterances) > 0:
            print('{} utterance(s) were ignored due to lack of features, please see {} for more information.'.format(
                len(self.ignored_utterances), self.log_file))
            root_logger.warning(
                'The following utterances were ignored due to lack of features: {}. '
                'See relevant logs for more information'.format(', '.join(self.ignored_utterances)))
        if len(no_transcription_files) > 0:
            print(
                '{} wav file(s) were ignored because neither a .lab file nor a .TextGrid file could be found, '
                'please see {} for more information'.format(len(no_transcription_files), self.log_file))
            root_logger.warning(
                'The following wav files were ignored due to the lack of a .lab or a .TextGrid file: {}.'.format(
                    ', '.join(no_transcription_files)))
        if textgrid_read_errors:
            print('{} TextGrid files were ignored due to errors loading them. '
                  'Please see {} for more information on the errors.'.format(len(textgrid_read_errors), self.log_file))
            for k, v in textgrid_read_errors.items():
                root_logger.warning('The TextGrid file {} gave the following error on load:\n\n{}'.format(k, v))
        if len(unsupported_sample_rate) > 0:
            print(
                '{} wav file(s) were ignored because they had a sample rate less than 16000, '
                'which is not currently supported, please see {} for more information'.format(
                    len(unsupported_sample_rate), self.log_file))
            root_logger.warning(
                'The following wav files were ignored due to a sample rate lower than 16000: {}.'.format(
                    ', '.join(unsupported_sample_rate)))
        if decode_error_files:
            print('There was an issue reading {} text file(s). '
                  'Please see {} for more information.'.format(len(decode_error_files), self.log_file))
            root_logger.warning(
                'The following lab files were ignored because they could not be parsed with utf8: {}.'.format(
                    ', '.join(decode_error_files)))

        bad_speakers = []
        for speaker in self.speak_utt_mapping.keys():
            count = 0
            for k, v in self.sample_rates.items():
                if speaker in v:
                    count += 1
            if count > 1:
                bad_speakers.append(speaker)
        if bad_speakers:
            msg = 'The following speakers had multiple sampling rates: {}. ' \
                  'Please make sure that each speaker has a consistent sampling rate.'.format(', '.join(bad_speakers))
            root_logger.error(msg)
            raise SampleRateError(msg)

        if len(self.speak_utt_mapping) < self.num_jobs:
            self.num_jobs = len(self.speak_utt_mapping)
        if self.num_jobs < len(self.sample_rates.keys()):
            self.num_jobs = len(self.sample_rates.keys())
            msg = 'The number of jobs was set to {} due to the different sample rates in the dataset. ' \
                  'If you would like to use fewer parallel jobs, ' \
                  'please resample all wav files to the same sample rate.'.format(self.num_jobs)
            print(msg)
            root_logger.warning(msg)
        self.find_best_groupings()

    @property
    def word_set(self):
        return set(self.word_counts)

    def find_best_groupings(self):
        if self.segments:
            ratio = len(self.segments.keys()) / len(self.utt_speak_mapping.keys())
            segment_job_num = int(ratio * self.num_jobs)
            if segment_job_num == 0:
                segment_job_num = 1
        else:
            segment_job_num = 0
        full_wav_job_num = self.num_jobs - segment_job_num
        num_sample_rates = len(self.sample_rates.keys())
        jobs_per_sample_rate = {x: 1 for x in self.sample_rates.keys()}
        remaining_jobs = self.num_jobs - num_sample_rates
        while remaining_jobs > 0:
            min_num = min(jobs_per_sample_rate.values())
            addable = sorted([k for k, v in jobs_per_sample_rate.items() if v == min_num],
                             key=lambda x: -1 * len(self.sample_rates[x]))
            jobs_per_sample_rate[addable[0]] += 1
            remaining_jobs -= 1
        self.speaker_groups = []
        self.mfcc_configs = []
        job_num = 0
        for k, v in jobs_per_sample_rate.items():
            speakers = sorted(self.sample_rates[k])
            groups = [[] for x in range(v)]
            configs = [MfccConfig(self.mfcc_directory, job=job_num + x,
                                  kwargs={'sample-frequency': k, 'low-freq': 20, 'high-freq': 7800})
                       for x in range(v)]
            ind = 0
            while speakers:
                s = speakers.pop(0)
                groups[ind].append(s)
                ind += 1
                if ind >= v:
                    ind = 0
            job_num += v
            self.speaker_groups.extend(groups)
            self.mfcc_configs.extend(configs)
        self.groups = []
        for x in self.speaker_groups:
            g = []
            for s in x:
                g.extend(self.speak_utt_mapping[s])
            self.groups.append(g)

    def speaker_utterance_info(self):
        num_speakers = len(self.speak_utt_mapping.keys())
        average_utterances = sum(len(x) for x in self.speak_utt_mapping.values()) / num_speakers
        msg = 'Number of speakers in corpus: {}, average number of utterances per speaker: {}'.format(
            num_speakers, average_utterances)
        root_logger = logging.getLogger()
        root_logger.info(msg)
        return msg

    def parse_mfcc_logs(self):
        pass

    @property
    def num_utterances(self):
        return len(self.utt_speak_mapping)

    @property
    def mfcc_directory(self):
        return os.path.join(self.output_directory, 'mfcc')

    @property
    def mfcc_log_directory(self):
        return os.path.join(self.mfcc_directory, 'log')

    @property
    def grouped_wav(self):
        output = []
        for g in self.groups:
            done = set()
            output_g = []
            for u in g:
                if not self.segments:
                    try:
                        output_g.append([u, self.utt_wav_mapping[u]])
                    except KeyError:
                        pass
                else:
                    try:
                        r = self.segments[u].split(' ')[0]
                    except KeyError:
                        continue
                    if r not in done:
                        output_g.append([r, self.utt_wav_mapping[r]])
                        done.add(r)
            output.append(output_g)
        return output

    @property
    def grouped_feat(self):
        output = []
        for g in self.groups:
            output_g = []
            for u in g:
                try:
                    output_g.append([u, self.feat_mapping[u]])
                except KeyError:
                    pass
            output.append(output_g)
        return output

    def grouped_text(self, dictionary=None):
        output = []
        for g in self.groups:
            output_g = []
            for u in g:
                if dictionary is None:
                    try:
                        new_text = self.text_mapping[u]
                    except KeyError:
                        continue
                else:
                    try:
                        text = self.text_mapping[u].split()
                    except KeyError:
                        continue
                    new_text = []
                    for t in text:
                        lookup = dictionary.separate_clitics(t)
                        if lookup is None:
                            continue
                        new_text.extend(x for x in lookup if x != '')
                output_g.append([u, new_text])
            output.append(output_g)
        return output

    def grouped_text_int(self, dictionary):
        oov_code = dictionary.oov_int
        all_oovs = []
        output = []
        grouped_texts = self.grouped_text(dictionary)
        for g in grouped_texts:
            output_g = []
            for u, text in g:
                oovs = []
                for i in range(len(text)):
                    t = text[i]
                    lookup = dictionary.to_int(t)
                    if lookup is None:
                        continue
                    if lookup == oov_code:
                        oovs.append(t)
                    text[i] = lookup
                if oovs:
                    all_oovs.append(u + ' ' + ', '.join(oovs))
                new_text = map(str, (x for x in text if isinstance(x, int)))
                output_g.append([u, ' '.join(new_text)])
            output.append(output_g)
        return output, all_oovs

    @property
    def grouped_cmvn(self):
        output = []
        try:
            for g in self.speaker_groups:
                output_g = []
                for s in sorted(g):
                    try:
                        output_g.append([s, self.cmvn_mapping[s]])
                    except KeyError:
                        pass
                output.append(output_g)
        except KeyError:
            raise CorpusError(
                'Something went wrong while setting up the corpus. '
                'Please delete the {} folder and try again.'.format(self.output_directory))
        return output

    @property
    def grouped_utt2spk(self):
        output = []
        for g in self.groups:
            output_g = []
            for u in sorted(g):
                try:
                    output_g.append([u, self.utt_speak_mapping[u]])
                except KeyError:
                    pass
            output.append(output_g)
        return output

    def get_word_frquency(self, dictionary):
        word_counts = Counter()
        for u, text in self.text_mapping.items():
            new_text = []
            text = text.split()
            for t in text:
                lookup = dictionary.separate_clitics(t)
                if lookup is None:
                    continue
                new_text.extend(x for x in lookup if x != '')
            word_counts.update(new_text)
        return {k: v / sum(word_counts.values()) for k, v in word_counts.items()}

    def grouped_utt2fst(self, dictionary, num_frequent_words=10):
        word_frequencies = self.get_word_frquency(dictionary)
        most_frequent = sorted(word_frequencies.items(), key=lambda x: -x[1])[:num_frequent_words]
        output = []
        for g in self.groups:
            output_g = []
            for u in g:
                try:
                    text = self.text_mapping[u].split()
                except KeyError:
                    continue
                new_text = []
                for t in text:
                    lookup = dictionary.separate_clitics(t)
                    if lookup is None:
                        continue
                    new_text.extend(x for x in lookup if x != '')
                try:
                    fst_text = dictionary.create_utterance_fst(new_text, most_frequent)
                except ZeroDivisionError:
                    print(u, text, new_text)
                    raise
                output_g.append([u, fst_text])
            output.append(output_g)
        return output

    @property
    def grouped_segments(self):
        output = []
        for g in self.groups:
            output_g = []
            for u in g:
                try:
                    output_g.append([u, self.segments[u]])
                except KeyError:
                    pass
            output.append(output_g)
        return output

    @property
    def grouped_spk2utt(self):
        output = []
        for g in self.speaker_groups:
            output_g = []
            for s in sorted(g):
                try:
                    output_g.append([s, sorted(self.speak_utt_mapping[s])])
                except KeyError:
                    pass
            output.append(output_g)
        return output

    def get_wav_duration(self, utt):
        if utt in self.wav_durations:
            return self.wav_durations[utt]
        if not self.segments:
            wav_path = self.utt_wav_mapping[utt]
        else:
            rec = self.segments[utt].split(' ')[0]
            wav_path = self.utt_wav_mapping[rec]
        with wave.open(wav_path, 'rb') as soundf:
            sr = soundf.getframerate()
            nframes = soundf.getnframes()
        return nframes / sr

    @property
    def split_directory(self):
        return os.path.join(self.output_directory, 'split{}'.format(self.num_jobs))

    def write(self):
        self._write_speak_utt()
        self._write_utt_speak()
        self._write_text()
        self._write_wavscp()

    def _write_utt_speak(self):
        utt2spk = os.path.join(self.output_directory, 'utt2spk')
        output_mapping(self.utt_speak_mapping, utt2spk)

    def _write_speak_utt(self):
        spk2utt = os.path.join(self.output_directory, 'spk2utt')
        output_mapping(self.speak_utt_mapping, spk2utt)

    def _write_text(self):
        text = os.path.join(self.output_directory, 'text')
        output_mapping(self.text_mapping, text)

    def _write_wavscp(self):
        wavscp = os.path.join(self.output_directory, 'wav.scp')
        output_mapping(self.utt_wav_mapping, wavscp)

    def _write_segments(self):
        if not self.segments:
            return
        segments = os.path.join(self.output_directory, 'segments')
        output_mapping(self.segments, segments)

    def _split_utt2spk(self, directory):
        pattern = 'utt2spk.{}'
        save_groups(self.grouped_utt2spk, directory, pattern)

    def _split_utt2fst(self, directory, dictionary):
        pattern = 'utt2fst.{}'
        save_groups(self.grouped_utt2fst(dictionary), directory, pattern, multiline=True)

    def _split_segments(self, directory):
        if not self.segments:
            return
        pattern = 'segments.{}'
        save_groups(self.grouped_segments, directory, pattern)

    def _split_spk2utt(self, directory):
        pattern = 'spk2utt.{}'
        save_groups(self.grouped_spk2utt, directory, pattern)

    def _split_wavs(self, directory):
        pattern = 'wav.{}.scp'
        save_groups(self.grouped_wav, directory, pattern)

    def _split_feats(self, directory):
        if not self.feat_mapping:
            feat_path = os.path.join(self.output_directory, 'feats.scp')
            self.feat_mapping = load_scp(feat_path)
        pattern = 'feats.{}.scp'
        save_groups(self.grouped_feat, directory, pattern)

    def _split_texts(self, directory, dictionary=None):
        pattern = 'text.{}'
        save_groups(self.grouped_text(dictionary), directory, pattern)
        if dictionary is not None:
            pattern = 'text.{}.int'
            ints, all_oovs = self.grouped_text_int(dictionary)
            save_groups(ints, directory, pattern)
            if all_oovs:
                with open(os.path.join(directory, 'utterance_oovs.txt'), 'w', encoding='utf8') as f:
                    for oov in sorted(all_oovs):
                        f.write(oov + '\n')
            dictionary.save_oovs_found(directory)

    def _split_cmvns(self, directory):
        if not self.cmvn_mapping:
            cmvn_path = os.path.join(self.output_directory, 'cmvn.scp')
            self.cmvn_mapping = load_scp(cmvn_path)
        pattern = 'cmvn.{}.scp'
        save_groups(self.grouped_cmvn, directory, pattern)

    def create_mfccs(self):
        log_directory = self.mfcc_log_directory
        os.makedirs(log_directory, exist_ok=True)
        if os.path.exists(os.path.join(self.mfcc_directory, 'cmvn')):
            print("Using previous MFCCs")
            return
        print('Calculating MFCCs...')
        self._split_wavs(self.mfcc_log_directory)
        self._split_segments(self.mfcc_log_directory)
        mfcc(self.mfcc_directory, log_directory, self.num_jobs, self.mfcc_configs)
        self.parse_mfcc_logs()
        self._combine_feats()
        print('Calculating CMVN...')
        self._calc_cmvn()

    def _combine_feats(self):
        root_logger = logging.getLogger()
        self.feat_mapping = {}
        feat_path = os.path.join(self.output_directory, 'feats.scp')
        with open(feat_path, 'w') as outf:
            for i in range(self.num_jobs):
                path = os.path.join(self.mfcc_directory, 'raw_mfcc.{}.scp'.format(i))
                with open(path, 'r') as inf:
                    for line in inf:
                        line = line.strip()
                        if line == '':
                            continue
                        f = line.split(maxsplit=1)
                        self.feat_mapping[f[0]] = f[1]
                        outf.write(line + '\n')
                os.remove(path)
        if len(self.feat_mapping.keys()) != len(self.utt_speak_mapping.keys()):
            for k in self.utt_speak_mapping.keys():
                if k not in self.feat_mapping:
                    self.ignored_utterances.append(k)
            print('Some utterances were ignored due to lack of features, please see {} for more information.'.format(
                self.log_file))
            root_logger.warning(
                'The following utterances were ignored due to lack of features: {}. '
                'See relevant logs for more information'.format(', '.join(self.ignored_utterances)))
            for k in self.ignored_utterances:
                del self.utt_speak_mapping[k]
                try:
                    del self.utt_wav_mapping[k]
                except KeyError:
                    pass
                try:
                    del self.segments[k]
                except KeyError:
                    pass
                try:
                    del self.text_mapping[k]
                except KeyError:
                    pass
            for k, v in self.speak_utt_mapping.items():
                self.speak_utt_mapping[k] = list(filter(lambda x: x in self.feat_mapping, v))

    def _calc_cmvn(self):
        spk2utt = os.path.join(self.output_directory, 'spk2utt')
        feats = os.path.join(self.output_directory, 'feats.scp')
        cmvn_directory = os.path.join(self.mfcc_directory, 'cmvn')
        os.makedirs(cmvn_directory, exist_ok=True)
        cmvn_ark = os.path.join(cmvn_directory, 'cmvn.ark')
        cmvn_scp = os.path.join(cmvn_directory, 'cmvn.scp')
        log_path = os.path.join(cmvn_directory, 'cmvn.log')
        with open(log_path, 'w') as logf:
            subprocess.call([thirdparty_binary('compute-cmvn-stats'),
                             '--spk2utt=ark:' + spk2utt,
                             'scp:' + feats,
                             'ark,scp:{},{}'.format(cmvn_ark, cmvn_scp)],
                            stderr=logf)
        shutil.copy(cmvn_scp, os.path.join(self.output_directory, 'cmvn.scp'))
        self.cmvn_mapping = load_scp(cmvn_scp)

    def _split_and_norm_feats(self):
        split_dir = self.split_directory
        log_dir = os.path.join(split_dir, 'log')
        os.makedirs(log_dir, exist_ok=True)
        with open(os.path.join(log_dir, 'norm.log'), 'w') as logf:
            for i in range(self.num_jobs):
                path = os.path.join(split_dir, 'cmvndeltafeats.{}'.format(i))
                utt2spkpath = os.path.join(split_dir, 'utt2spk.{}'.format(i))
                cmvnpath = os.path.join(split_dir, 'cmvn.{}.scp'.format(i))
                featspath = os.path.join(split_dir, 'feats.{}.scp'.format(i))
                if not os.path.exists(path):
                    with open(path, 'wb') as outf:
                        cmvn_proc = subprocess.Popen([thirdparty_binary('apply-cmvn'),
                                                      '--utt2spk=ark:' + utt2spkpath,
                                                      'scp:' + cmvnpath,
                                                      'scp:' + featspath,
                                                      'ark:-'],
                                                     stdout=subprocess.PIPE,
                                                     stderr=logf)
                        deltas_proc = subprocess.Popen([thirdparty_binary('add-deltas'),
                                                        'ark:-', 'ark:-'],
                                                       stdin=cmvn_proc.stdout,
                                                       stdout=outf,
                                                       stderr=logf)
                        deltas_proc.communicate()
                    with open(path, 'rb') as inf, open(path + '_sub', 'wb') as outf:
                        subprocess.call([thirdparty_binary("subset-feats"),
                                         "--n=10", "ark:-", "ark:-"],
                                        stdin=inf, stderr=logf, stdout=outf)

    def get_feat_dim(self):
        directory = self.split_directory
        path = os.path.join(self.split_directory, 'cmvndeltafeats.0')
        with open(path, 'rb') as f, open(os.devnull, 'w') as devnull:
            dim_proc = subprocess.Popen([thirdparty_binary('feat-to-dim'),
                                         'ark,s,cs:-', '-'],
                                        stdin=f,
                                        stdout=subprocess.PIPE,
                                        stderr=devnull)
            stdout, stderr = dim_proc.communicate()
            feats = stdout.decode('utf8').strip()
        return feats

    def initialize_corpus(self, dictionary, skip_input=False):
        root_logger = logging.getLogger()
        split_dir = self.split_directory
        self.write()
        split = False
        if not os.path.exists(split_dir):
            split = True
            root_logger.info('Setting up training data...')
            print('Setting up training data...')
            os.makedirs(split_dir)
            self._split_wavs(split_dir)
            self._split_utt2spk(split_dir)
            self._split_spk2utt(split_dir)
            self._split_texts(split_dir, dictionary)
            self._split_utt2fst(split_dir, dictionary)
            if not skip_input and dictionary.oovs_found:
                user_input = input('There were words not found in the dictionary. '
                                   'Would you like to abort to fix them? (Y/N)')
                if user_input.lower() == 'y':
                    sys.exit(1)
        self.create_mfccs()
        if split:
            self._split_feats(split_dir)
            self._split_cmvns(split_dir)
            self._split_and_norm_feats()
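

# Illustrative usage sketch for the Corpus class (hypothetical paths; the call
# sequence mirrors how the aligner is expected to drive this class, but is an
# assumption here, not part of the original module):
#
#     corpus = Corpus('/data/my_corpus', '/tmp/mfa_output', num_jobs=4)
#     print(corpus.speaker_utterance_info())
#     # Given a Dictionary object from aligner.dictionary:
#     # corpus.initialize_corpus(dictionary)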