"""
Speaker classification
======================
"""
from __future__ import annotations
import collections
import csv
import logging
import os
import pickle
import random
import shutil
import subprocess
import sys
import time
import typing
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as np
import sqlalchemy
import tqdm
import yaml
from sklearn import decomposition, metrics
from sqlalchemy.orm import joinedload, selectinload
from montreal_forced_aligner.abc import FileExporterMixin, TopLevelMfaWorker
from montreal_forced_aligner.alignment.multiprocessing import construct_output_path
from montreal_forced_aligner.config import (
GLOBAL_CONFIG,
IVECTOR_DIMENSION,
MEMORY,
PLDA_DIMENSION,
XVECTOR_DIMENSION,
)
from montreal_forced_aligner.corpus.features import (
ExportIvectorsArguments,
ExportIvectorsFunction,
PldaModel,
)
from montreal_forced_aligner.corpus.ivector_corpus import IvectorCorpusMixin
from montreal_forced_aligner.data import (
ClusterType,
DistanceMetric,
ManifoldAlgorithm,
WorkflowType,
)
from montreal_forced_aligner.db import (
Corpus,
File,
SoundFile,
Speaker,
SpeakerOrdering,
TextFile,
Utterance,
bulk_update,
)
from montreal_forced_aligner.diarization.multiprocessing import (
ComputeEerArguments,
ComputeEerFunction,
PldaClassificationArguments,
PldaClassificationFunction,
SpeechbrainArguments,
SpeechbrainClassificationFunction,
SpeechbrainEmbeddingFunction,
cluster_matrix,
visualize_clusters,
)
from montreal_forced_aligner.exceptions import KaldiProcessingError
from montreal_forced_aligner.helper import load_configuration, mfa_open
from montreal_forced_aligner.models import IvectorExtractorModel
from montreal_forced_aligner.textgrid import export_textgrid
from montreal_forced_aligner.utils import log_kaldi_errors, run_kaldi_function, thirdparty_binary
try:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
torch_logger = logging.getLogger("speechbrain.utils.torch_audio_backend")
torch_logger.setLevel(logging.ERROR)
torch_logger = logging.getLogger("speechbrain.utils.train_logger")
torch_logger.setLevel(logging.ERROR)
import torch
from speechbrain.pretrained import EncoderClassifier, SpeakerRecognition
from speechbrain.utils.metric_stats import EER
FOUND_SPEECHBRAIN = True
except (ImportError, OSError):
FOUND_SPEECHBRAIN = False
EncoderClassifier = None
if TYPE_CHECKING:
from montreal_forced_aligner.abc import MetaDict
__all__ = ["SpeakerDiarizer"]
logger = logging.getLogger("mfa")
[docs]
class SpeakerDiarizer(IvectorCorpusMixin, TopLevelMfaWorker, FileExporterMixin):
"""
Class for performing speaker classification, not currently very functional, but
is planned to be expanded in the future
Parameters
----------
ivector_extractor_path : str
Path to ivector extractor model, or "speechbrain"
expected_num_speakers: int, optional
Number of speakers in the corpus, if known
cluster: bool
Flag for whether speakers should be clustered instead of classified
evaluation_mode: bool
Flag for evaluating against existing speaker labels
cuda: bool
Flag for using CUDA for speechbrain models
metric: str or :class:`~montreal_forced_aligner.data.DistanceMetric`
One of "cosine", "plda", or "euclidean"
cluster_type: str or :class:`~montreal_forced_aligner.data.ClusterType`
Clustering algorithm
distance_threshold: float, optional
Distance threshold to use when clustering; defaults to 0.25 for speechbrain models when not specified
"""
def __init__(
    self,
    ivector_extractor_path: str = "speechbrain",
    expected_num_speakers: int = 0,
    cluster: bool = True,
    evaluation_mode: bool = False,
    cuda: bool = False,
    use_pca: bool = True,
    metric: typing.Union[str, DistanceMetric] = "cosine",
    cluster_type: typing.Union[str, ClusterType] = "hdbscan",
    manifold_algorithm: typing.Union[str, ManifoldAlgorithm] = "tsne",
    distance_threshold: Optional[float] = None,
    score_threshold: Optional[float] = None,
    min_cluster_size: int = 60,
    max_iterations: int = 10,
    linkage: str = "average",
    **kwargs,
):
    """
    Configure the diarizer and load the ivector extractor when one is given

    Parameters
    ----------
    ivector_extractor_path : str
        Path to ivector extractor model, or "speechbrain"
    expected_num_speakers: int
        Number of speakers in the corpus, if known
    cluster: bool
        Flag for whether speakers should be clustered instead of classified
    evaluation_mode: bool
        Flag for evaluating against existing speaker labels
    cuda: bool
        Flag for using CUDA for speechbrain models
    use_pca: bool
        Stored on the instance as ``use_pca``
    metric: str or :class:`~montreal_forced_aligner.data.DistanceMetric`
        Member (or member name) of the distance metric enum
    cluster_type: str or :class:`~montreal_forced_aligner.data.ClusterType`
        Member (or member name) of the clustering algorithm enum
    manifold_algorithm: str or :class:`~montreal_forced_aligner.data.ManifoldAlgorithm`
        Member (or member name) of the manifold algorithm enum
    distance_threshold: float, optional
        Distance threshold; when omitted it becomes 0.25 for speechbrain models
    score_threshold: float, optional
        Minimum classification score; lower scoring utterances are assigned to the unknown speaker
    min_cluster_size: int
        Minimum cluster size
    max_iterations: int
        Number of iterations for MFA-style clustering
    linkage: str
        Linkage setting stored on the instance as ``linkage``
    **kwargs
        Forwarded to the parent classes (merged with the extractor's parameters when one is loaded)
    """
    self.use_xvector = False
    self.ivector_extractor = None
    self.ivector_extractor_path = ivector_extractor_path
    if ivector_extractor_path == "speechbrain":
        if not FOUND_SPEECHBRAIN:
            logger.error(
                "Could not import speechbrain, please ensure it is installed via `pip install speechbrain`"
            )
            sys.exit(1)
        self.use_xvector = True
    else:
        self.ivector_extractor = IvectorExtractorModel(ivector_extractor_path)
        # The extractor's stored parameters take precedence over caller supplied keyword arguments
        kwargs.update(self.ivector_extractor.parameters)
    super().__init__(**kwargs)
    self.expected_num_speakers = expected_num_speakers
    self.cluster = cluster
    # The signature accepts enum members as well as their names; looking an enum up by one of
    # its own members raises KeyError, so only names go through the lookup
    self.metric = metric if isinstance(metric, DistanceMetric) else DistanceMetric[metric]
    self.cuda = cuda
    self.cluster_type = (
        cluster_type if isinstance(cluster_type, ClusterType) else ClusterType[cluster_type]
    )
    self.manifold_algorithm = (
        manifold_algorithm
        if isinstance(manifold_algorithm, ManifoldAlgorithm)
        else ManifoldAlgorithm[manifold_algorithm]
    )
    self.distance_threshold = distance_threshold
    self.score_threshold = score_threshold
    if self.distance_threshold is None:
        if self.use_xvector:
            self.distance_threshold = 0.25
    self.evaluation_mode = evaluation_mode
    self.min_cluster_size = min_cluster_size
    self.linkage = linkage
    self.use_pca = use_pca
    self.max_iterations = max_iterations
    self.current_labels = []
    self.classification_score = None
    self.initial_plda_score_threshold = 0
    self.plda_score_threshold = 10
    self.initial_sb_score_threshold = 0.25
    self.ground_truth_utt2spk = {}
    self.ground_truth_speakers = {}
    self.single_clusters = set()
    # Read by breakup_large_clusters, which can run before cluster_utterances_mfa assigns it
    self._unknown_speaker_break_up_count = 0
[docs]
@classmethod
def parse_parameters(
    cls,
    config_path: Optional[str] = None,
    args: Optional[Dict[str, Any]] = None,
    unknown_args: Optional[List[str]] = None,
) -> MetaDict:
    """
    Build the configuration for speaker classification from a config file and command-line arguments

    Values parsed from the command line override values read from the config file.

    Parameters
    ----------
    config_path: str
        Config path
    args: dict[str, Any]
        Parsed arguments
    unknown_args: list[str]
        Optional list of arguments that were not parsed

    Returns
    -------
    dict[str, Any]
        Configuration parameters
    """
    parameters = {}
    if config_path and os.path.exists(config_path):
        for key, value in load_configuration(config_path).items():
            if key != "features":
                # Nullable fields that were left empty in the config become empty lists
                if value is None and key in cls.nullable_fields:
                    value = []
                parameters[key] = value
                continue
            # The "features" section is flattened into the top level, with its "type" key
            # renamed to "feature_type"
            if "type" in value:
                value["feature_type"] = value.pop("type")
            parameters.update(value)
    parameters.update(cls.parse_args(args, unknown_args))
    return parameters
# noinspection PyTypeChecker
[docs]
def setup(self) -> None:
    """
    Sets up the corpus and speaker classifier

    Loads the corpus and either the SpeechBrain embeddings or the ivectors, depending on
    whether an ivector extractor was configured, and caches the existing speaker labels when
    running in evaluation mode.

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
        If there were any errors in running Kaldi binaries
    """
    # A previous call already completed
    if self.initialized:
        return
    super().setup()
    self.create_new_current_workflow(WorkflowType.speaker_diarization)
    wf = self.current_workflow
    if wf.done:
        logger.info("Diarization already done, skipping initialization.")
        return
    log_dir = os.path.join(self.working_directory, "log")
    os.makedirs(log_dir, exist_ok=True)
    try:
        # NOTE(review): the nesting below is reconstructed; the `else` is paired with the
        # extractor check because that branch dereferences self.ivector_extractor -- confirm
        if self.ivector_extractor is None:  # Download models if needed
            _ = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=os.path.join(
                    GLOBAL_CONFIG.current_profile.temporary_directory,
                    "models",
                    "EncoderClassifier",
                ),
            )
            _ = SpeakerRecognition.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=os.path.join(
                    GLOBAL_CONFIG.current_profile.temporary_directory,
                    "models",
                    "SpeakerRecognition",
                ),
            )
            self.initialize_database()
            self._load_corpus()
            self.initialize_jobs()
            self.load_embeddings()
            if self.cluster:
                self.compute_speaker_embeddings()
        else:
            if not self.has_ivectors():
                # NOTE(review): this is a plain string comparison of version numbers
                if self.ivector_extractor.meta["version"] < "2.1":
                    logger.warning(
                        "The ivector extractor was trained in an earlier version of MFA. "
                        "There may be incompatibilities in feature generation that cause errors. "
                        "Please download the latest version of the model via `mfa model download`, "
                        "use a different ivector extractor, or use version 2.0.6 of MFA."
                    )
                self.ivector_extractor.export_model(self.working_directory)
                self.load_corpus()
                self.extract_ivectors()
                self.compute_speaker_ivectors()
        if self.evaluation_mode:
            # Snapshot the current utterance -> speaker assignment so that later predictions
            # can be compared against it
            self.ground_truth_utt2spk = {}
            with self.session() as session:
                query = session.query(Utterance.id, Utterance.speaker_id, Speaker.name).join(
                    Utterance.speaker
                )
                for u_id, s_id, name in query:
                    self.ground_truth_utt2spk[u_id] = s_id
                    self.ground_truth_speakers[s_id] = name
    except Exception as e:
        # Kaldi errors get their logs reported before the exception propagates
        if isinstance(e, KaldiProcessingError):
            log_kaldi_errors(e.error_logs)
            e.update_log_file()
        raise
    self.initialized = True
[docs]
def plda_classification_arguments(self) -> List[PldaClassificationArguments]:
    """
    Generate Job arguments for :class:`~montreal_forced_aligner.diarization.multiprocessing.PldaClassificationFunction`

    Returns
    -------
    list[:class:`~montreal_forced_aligner.diarization.multiprocessing.PldaClassificationArguments`]
        Arguments for processing, one entry per job
    """
    job_arguments = []
    for job in self.jobs:
        job_arguments.append(
            PldaClassificationArguments(
                job.id,
                # Falls back to an empty string when no database string is available
                getattr(self, "db_string", ""),
                os.path.join(self.working_log_directory, f"plda_classification.{job.id}.log"),
                self.plda,
                self.speaker_ivector_path,
                self.num_utts_path,
                self.use_xvector,
            )
        )
    return job_arguments
[docs]
def classify_speakers(self):
    """
    Classify speakers based on ivector or speechbrain model

    Every utterance is assigned to the speaker returned by the classification function (or to
    the ``MFA_UNKNOWN`` speaker when its score is below ``score_threshold``), and one row per
    utterance is written to ``speaker_classification_results.csv`` in the working directory.
    """
    self.setup()
    logger.info("Classifying utterances...")
    with self.session() as session, tqdm.tqdm(
        total=self.num_utterances, disable=GLOBAL_CONFIG.quiet
    ) as progress, mfa_open(
        os.path.join(self.working_directory, "speaker_classification_results.csv"), "w"
    ) as results_file:
        writer = csv.DictWriter(
            results_file, ["utt_id", "file", "begin", "end", "speaker", "score"]
        )
        writer.writeheader()

        # Lookup tables used when writing the result rows
        file_names = {}
        for utterance_id, file_name in session.query(Utterance.id, File.name).join(
            Utterance.file
        ):
            file_names[utterance_id] = file_name
        utterance_times = {}
        for utterance_id, begin, end in session.query(
            Utterance.id, Utterance.begin, Utterance.end
        ):
            utterance_times[utterance_id] = (begin, end)

        utterance_mapping = []
        next_speaker_id = self.get_next_primary_key(Speaker)
        speaker_mapping = {}
        existing_speakers = {}
        for existing_id, existing_name in session.query(Speaker.id, Speaker.name):
            existing_speakers[existing_name] = existing_id
        self.classification_score = 0

        # Make sure the placeholder speaker for low scoring utterances exists
        unknown_query = session.query(Speaker).filter(Speaker.name == "MFA_UNKNOWN")
        if unknown_query.first() is None:
            session.add(Speaker(id=next_speaker_id, name="MFA_UNKNOWN"))
            session.commit()
            next_speaker_id += 1
        unknown_speaker_id = unknown_query.first().id

        if self.use_xvector:
            func = SpeechbrainClassificationFunction
            arguments = [
                SpeechbrainArguments(job.id, self.db_string, None, self.cuda, self.cluster)
                for job in self.jobs
            ]
        else:
            plda_transform_path = os.path.join(self.working_directory, "plda.pkl")
            with open(plda_transform_path, "rb") as plda_file:
                self.plda: PldaModel = pickle.load(plda_file)
            func = PldaClassificationFunction
            arguments = self.plda_classification_arguments()

        for utt_id, classified_speaker, score in run_kaldi_function(
            func, arguments, progress.update
        ):
            classified_speaker = str(classified_speaker)
            # Running mean of the scores over all utterances
            self.classification_score += score / self.num_utterances
            rejected = self.score_threshold is not None and score < self.score_threshold
            if rejected:
                speaker_id = unknown_speaker_id
            elif classified_speaker in existing_speakers:
                speaker_id = existing_speakers[classified_speaker]
            else:
                # First time this label is seen: reserve a new speaker row for it
                if classified_speaker not in speaker_mapping:
                    speaker_mapping[classified_speaker] = {
                        "id": next_speaker_id,
                        "name": classified_speaker,
                    }
                    next_speaker_id += 1
                speaker_id = speaker_mapping[classified_speaker]["id"]
            utterance_mapping.append({"id": utt_id, "speaker_id": speaker_id})
            writer.writerow(
                {
                    "utt_id": utt_id,
                    "file": file_names[utt_id],
                    "begin": utterance_times[utt_id][0],
                    "end": utterance_times[utt_id][1],
                    "speaker": classified_speaker,
                    "score": score,
                }
            )

        if self.stopped.stop_check():
            logger.debug("Stopping clustering early.")
            return
        if speaker_mapping:
            session.bulk_insert_mappings(Speaker, list(speaker_mapping.values()))
            session.flush()
        session.commit()
        bulk_update(session, Utterance, utterance_mapping)
        session.commit()

    keep_existing_labels = self.evaluation_mode
    if not keep_existing_labels:
        self.clean_up_unknown_speaker()
    self.fix_speaker_ordering()
    if not keep_existing_labels:
        self.cleanup_empty_speakers()
    self.refresh_speaker_vectors()
    if keep_existing_labels:
        self.evaluate_classification()
def map_speakers_to_ground_truth(self):
    """
    Map every current speaker id to the ground truth speaker it overlaps with the most

    A confusion matrix between the current utterance speaker ids and the ids stored in
    ``ground_truth_utt2spk`` is counted, and each current id is mapped to the ground truth id
    with the highest count in its row.

    Returns
    -------
    dict[int, int]
        Mapping of current speaker id to ground truth speaker id; empty when the database
        has no utterances
    """
    with self.session() as session:
        utterance_ids = []
        labels = []
        for utt_id, s_id in session.query(Utterance.id, Utterance.speaker_id):
            utterance_ids.append(utt_id)
            labels.append(s_id)
    if not utterance_ids:
        # argmax over an empty confusion matrix would raise
        return {}
    ground_truth = np.array([self.ground_truth_utt2spk[x] for x in utterance_ids])
    cluster_labels = np.unique(labels)
    ground_truth_labels = np.unique(ground_truth)
    # Row/column positions are looked up in dictionaries rather than scanning the label
    # arrays for every utterance
    row_of = {int(label): i for i, label in enumerate(cluster_labels)}
    column_of = {int(label): i for i, label in enumerate(ground_truth_labels)}
    # 64-bit counts: a 16-bit matrix wraps around once a cell exceeds 32767 utterances
    cm = np.zeros((cluster_labels.shape[0], ground_truth_labels.shape[0]), dtype=np.int64)
    for y_pred, y in zip(labels, ground_truth):
        if y_pred < 0:
            continue
        cm[row_of[int(y_pred)], column_of[int(y)]] += 1
    cm_argmax = cm.argmax(axis=1)
    return {
        int(cluster_labels[i]): int(ground_truth_labels[cm_argmax[i]])
        for i in range(cluster_labels.shape[0])
    }
[docs]
def evaluate_clustering(self) -> None:
    """
    Compute clustering metric scores and output clustering evaluation results

    Predicted speakers are first mapped onto ground truth speakers, one row per utterance is
    written to ``diarization_evaluation_results.csv`` in the working directory, and the
    scikit-learn clustering scores are logged.
    """
    label_to_ground_truth_mapping = self.map_speakers_to_ground_truth()
    columns = ["file", "begin", "end", "text", "predicted_speaker", "ground_truth_speaker"]
    results_path = os.path.join(self.working_directory, "diarization_evaluation_results.csv")
    with self.session() as session, mfa_open(results_path, "w") as f:
        writer = csv.DictWriter(f, fieldnames=columns)
        writer.writeheader()
        predicted_utt2spk = {}
        utterances = session.query(
            Utterance.id,
            File.name,
            Utterance.begin,
            Utterance.end,
            Utterance.text,
            Utterance.speaker_id,
        ).join(Utterance.file)
        for u_id, file_name, begin, end, text, s_id in utterances:
            mapped_id = label_to_ground_truth_mapping[s_id]
            predicted_utt2spk[u_id] = mapped_id
            true_id = self.ground_truth_utt2spk[u_id]
            row = {
                "file": file_name,
                "begin": begin,
                "end": end,
                "text": text,
                "predicted_speaker": self.ground_truth_speakers[mapped_id],
                "ground_truth_speaker": self.ground_truth_speakers[true_id],
            }
            writer.writerow(row)

        # Both label arrays follow the iteration order of the ground truth dictionary
        ground_truth_labels = np.array(list(self.ground_truth_utt2spk.values()))
        predicted_labels = np.array([predicted_utt2spk[k] for k in self.ground_truth_utt2spk])
        rand_score = metrics.adjusted_rand_score(ground_truth_labels, predicted_labels)
        ami_score = metrics.adjusted_mutual_info_score(ground_truth_labels, predicted_labels)
        nmi_score = metrics.normalized_mutual_info_score(ground_truth_labels, predicted_labels)
        homogeneity_score = metrics.homogeneity_score(ground_truth_labels, predicted_labels)
        completeness_score = metrics.completeness_score(ground_truth_labels, predicted_labels)
        v_measure_score = metrics.v_measure_score(ground_truth_labels, predicted_labels)
        fm_score = metrics.fowlkes_mallows_score(ground_truth_labels, predicted_labels)
        logger.info(f"Adjusted Rand index score (0-1, higher is better): {rand_score:.4f}")
        logger.info(f"Normalized Mutual Information score (perfect=1.0): {nmi_score:.4f}")
        logger.info(f"Adjusted Mutual Information score (perfect=1.0): {ami_score:.4f}")
        logger.info(f"Homogeneity score (0-1, higher is better): {homogeneity_score:.4f}")
        logger.info(f"Completeness score (0-1, higher is better): {completeness_score:.4f}")
        logger.info(f"V measure score (0-1, higher is better): {v_measure_score:.4f}")
        logger.info(f"Fowlkes-Mallows score (0-1, higher is better): {fm_score:.4f}")
[docs]
def evaluate_classification(self) -> None:
    """
    Evaluate and output classification accuracy

    Predicted speakers are first mapped onto ground truth speakers, one row per utterance is
    written to ``diarization_evaluation_results.csv`` in the working directory, and weighted
    precision, recall and F1 are logged.
    """
    label_to_ground_truth_mapping = self.map_speakers_to_ground_truth()
    columns = ["file", "begin", "end", "text", "predicted_speaker", "ground_truth_speaker"]
    results_path = os.path.join(self.working_directory, "diarization_evaluation_results.csv")
    with self.session() as session, mfa_open(results_path, "w") as f:
        writer = csv.DictWriter(f, fieldnames=columns)
        writer.writeheader()
        predicted_utt2spk = {}
        utterances = session.query(
            Utterance.id,
            File.name,
            Utterance.begin,
            Utterance.end,
            Utterance.text,
            Utterance.speaker_id,
        ).join(Utterance.file)
        for u_id, file_name, begin, end, text, s_id in utterances:
            mapped_id = label_to_ground_truth_mapping[s_id]
            predicted_utt2spk[u_id] = mapped_id
            true_id = self.ground_truth_utt2spk[u_id]
            row = {
                "file": file_name,
                "begin": begin,
                "end": end,
                "text": text,
                "predicted_speaker": self.ground_truth_speakers[mapped_id],
                "ground_truth_speaker": self.ground_truth_speakers[true_id],
            }
            writer.writerow(row)

        ground_truth_labels = np.array(list(self.ground_truth_utt2spk.values()))
        # Utterances without a prediction are scored as the label -1
        predicted_labels = np.array(
            [predicted_utt2spk.get(k, -1) for k in self.ground_truth_utt2spk]
        )
        precision_score = metrics.precision_score(
            ground_truth_labels, predicted_labels, average="weighted"
        )
        recall_score = metrics.recall_score(
            ground_truth_labels, predicted_labels, average="weighted"
        )
        f1_score = metrics.f1_score(ground_truth_labels, predicted_labels, average="weighted")
        logger.info(f"Precision (0-1): {precision_score:.4f}")
        logger.info(f"Recall (0-1): {recall_score:.4f}")
        logger.info(f"F1 (0-1): {f1_score:.4f}")
@property
def num_utts_path(self) -> str:
    """Path to the archive listing the number of utterances per training speaker"""
    archive_name = "num_utts.ark"
    return os.path.join(self.working_directory, archive_name)
@property
def speaker_ivector_path(self) -> str:
    """Path to the archive holding the ivectors of the training speakers"""
    archive_name = "speaker_ivectors.ark"
    return os.path.join(self.working_directory, archive_name)
def visualize_clusters(self, ivectors, cluster_labels=None):
    """
    Plot a 2D projection of the given vectors and save it to ``cluster_plot.png``

    Parameters
    ----------
    ivectors: numpy.ndarray
        Vectors to project, one row per point
    cluster_labels: sequence, optional
        One label per row of ``ivectors``; integer labels are shown as "Cluster N", string
        labels are shown as is, and the integer label -1 is drawn in black as "Noise"
    """
    import seaborn as sns
    from matplotlib import pyplot as plt

    sns.set()
    metric = self.metric
    # The projection is computed with cosine distance when PLDA is the configured metric
    if metric is DistanceMetric.plda:
        metric = DistanceMetric.cosine
    points = visualize_clusters(
        ivectors, self.manifold_algorithm, metric, 10, getattr(self, "plda", None)
    )
    fig = plt.figure(1)
    ax = fig.add_subplot(111)
    if cluster_labels is not None:
        # Callers may pass plain lists; element-wise comparison below needs an array
        cluster_labels = np.asarray(cluster_labels)
        unique_labels = np.unique(cluster_labels)

        def is_noise(label) -> bool:
            return not isinstance(label, str) and label == -1

        num_colors = sum(1 for label in unique_labels if not is_noise(label))
        palette = sns.color_palette("tab20", num_colors)
        # Colours are assigned by position among the non-noise labels, so label values do
        # not have to be contiguous or start at zero
        color_position = 0
        for cluster in unique_labels:
            if is_noise(cluster):
                color = "k"
                name = "Noise"
                alpha = 0.75
            else:
                name = str(cluster) if isinstance(cluster, str) else f"Cluster {cluster}"
                color = palette[color_position]
                color_position += 1
                alpha = 1.0
            members = cluster_labels == cluster
            ax.scatter(
                points[members, 0], points[members, 1], color=color, label=name, alpha=alpha
            )
    else:
        ax.scatter(points[:, 0], points[:, 1])
    handles, labels = ax.get_legend_handles_labels()
    fig.subplots_adjust(bottom=0.3, wspace=0.33)
    plt.axis("off")
    lgd = ax.legend(
        handles,
        labels,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.1),
        fancybox=True,
        shadow=True,
        ncol=5,
    )
    plot_path = os.path.join(self.working_directory, "cluster_plot.png")
    plt.savefig(plot_path, bbox_extra_artists=(lgd,), bbox_inches="tight", transparent=True)
    if GLOBAL_CONFIG.current_profile.verbose:
        plt.show(block=False)
        plt.pause(10)
    logger.debug(f"Closing cluster plot, it has been saved to {plot_path}.")
    # Always release the figure: figure number 1 is reused, so an unclosed figure would have
    # the next plot drawn on top of this one
    plt.close(fig)
def export_xvectors(self):
    """
    Export the SpeechBrain embeddings and record each utterance's archive path

    Runs :class:`~montreal_forced_aligner.corpus.features.ExportIvectorsFunction` over all jobs,
    stores the returned archive path in ``Utterance.ivector_ark`` and then writes the vectors.
    """
    logger.info("Exporting SpeechBrain embeddings...")
    os.makedirs(self.split_directory, exist_ok=True)
    with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as progress:
        job_arguments = []
        for job in self.jobs:
            job_arguments.append(
                ExportIvectorsArguments(
                    job.id,
                    self.db_string,
                    job.construct_path(self.working_log_directory, "export_ivectors", "log"),
                    self.use_xvector,
                )
            )
        exported = run_kaldi_function(ExportIvectorsFunction, job_arguments, progress.update)
        utterance_mapping = [
            {"id": utt_id, "ivector_ark": ark_path} for utt_id, ark_path in exported
        ]
        with self.session() as session:
            bulk_update(session, Utterance, utterance_mapping)
            session.commit()
    self._write_ivectors()
def fix_speaker_ordering(self):
    """
    Rebuild the speaker ordering table from the current utterance assignments

    All existing :class:`~montreal_forced_aligner.db.SpeakerOrdering` rows are deleted and one
    row with index 10 is inserted for every distinct (speaker, file) pair that has at least
    one utterance.
    """
    # Imported explicitly: `import sqlalchemy` alone does not guarantee that the PostgreSQL
    # dialect submodule is available as an attribute
    from sqlalchemy.dialects.postgresql import insert as postgresql_insert

    with self.session() as session:
        query = (
            session.query(Speaker.id, File.id)
            .join(Utterance.speaker)
            .join(Utterance.file)
            .distinct()
        )
        speaker_ordering_mapping = [
            {"speaker_id": s_id, "file_id": f_id, "index": 10} for s_id, f_id in query
        ]
        session.execute(sqlalchemy.delete(SpeakerOrdering))
        session.flush()
        # An INSERT with an empty list of rows is not meaningful, so skip it when there is
        # nothing to insert
        if speaker_ordering_mapping:
            session.execute(
                postgresql_insert(SpeakerOrdering)
                .values(speaker_ordering_mapping)
                .on_conflict_do_nothing()
            )
        session.commit()
def initialize_mfa_clustering(self):
    """
    Generate the initial speaker labels used by the iterative MFA clustering

    Every utterance is classified once; utterances scoring below the initial threshold are
    assigned to the ``MFA_UNKNOWN`` speaker, the others to the classified speaker (created if
    it does not exist yet). Large speakers are then broken up and empty speakers removed.

    NOTE(review): requires the ``MFA_UNKNOWN`` speaker to exist already (it is looked up in
    ``existing_speakers``); the block nesting was reconstructed -- confirm against the
    original file.
    """
    with self.session() as session:
        next_speaker_id = self.get_next_primary_key(Speaker)
        speaker_mapping = {}
        existing_speakers = {
            name: s_id for s_id, name in session.query(Speaker.id, Speaker.name)
        }
        utterance_mapping = []
        self.classification_score = 0
        unk_count = 0  # counted but not used afterwards
        if self.use_xvector:
            arguments = [
                SpeechbrainArguments(j.id, self.db_string, None, self.cuda, self.cluster)
                for j in self.jobs
            ]
            func = SpeechbrainClassificationFunction
            score_threshold = self.initial_sb_score_threshold
            self.export_xvectors()
        else:
            plda_transform_path = os.path.join(self.working_directory, "plda.pkl")
            with open(plda_transform_path, "rb") as f:
                self.plda: PldaModel = pickle.load(f)
            arguments = self.plda_classification_arguments()
            func = PldaClassificationFunction
            score_threshold = self.initial_plda_score_threshold
        logger.info("Generating initial speaker labels...")
        # Current assignment, used to skip utterances whose speaker does not change
        utt2spk = {k: v for k, v in session.query(Utterance.id, Utterance.speaker_id)}
        with tqdm.tqdm(total=self.num_utterances, disable=GLOBAL_CONFIG.quiet) as pbar:
            for utt_id, classified_speaker, score in run_kaldi_function(
                func, arguments, pbar.update
            ):
                classified_speaker = str(classified_speaker)
                # Running mean of the scores over all utterances
                self.classification_score += score / self.num_utterances
                if score < score_threshold:
                    unk_count += 1
                    utterance_mapping.append(
                        {"id": utt_id, "speaker_id": existing_speakers["MFA_UNKNOWN"]}
                    )
                    continue
                if classified_speaker in existing_speakers:
                    speaker_id = existing_speakers[classified_speaker]
                else:
                    # First time this label is seen: reserve a new speaker row for it
                    if classified_speaker not in speaker_mapping:
                        speaker_mapping[classified_speaker] = {
                            "id": next_speaker_id,
                            "name": classified_speaker,
                        }
                        next_speaker_id += 1
                    speaker_id = speaker_mapping[classified_speaker]["id"]
                if speaker_id == utt2spk[utt_id]:
                    continue
                utterance_mapping.append({"id": utt_id, "speaker_id": speaker_id})
        if speaker_mapping:
            session.bulk_insert_mappings(Speaker, list(speaker_mapping.values()))
            session.flush()
        session.commit()
        # The indexes on utterance.speaker_id are dropped before the bulk update and
        # recreated afterwards
        session.execute(sqlalchemy.text("DROP INDEX IF EXISTS ix_utterance_speaker_id"))
        session.execute(sqlalchemy.text("DROP INDEX IF EXISTS utterance_position_index"))
        session.commit()
        bulk_update(session, Utterance, utterance_mapping)
        session.execute(
            sqlalchemy.text(
                "CREATE INDEX IF NOT EXISTS ix_utterance_speaker_id on utterance(speaker_id)"
            )
        )
        session.execute(
            sqlalchemy.text(
                'CREATE INDEX IF NOT EXISTS utterance_position_index on utterance(file_id, speaker_id, begin, "end", channel)'
            )
        )
        session.commit()
    self.breakup_large_clusters()
    self.cleanup_empty_speakers()
def export_speaker_ivectors(self):
    """
    Write the current speaker vectors and per-speaker utterance counts to disk

    Speaker vectors (xvectors for speechbrain models, ivectors otherwise) are piped as text
    into Kaldi's ``copy-vector``, which writes them to ``speaker_ivector_path``; the utterance
    count of each exported speaker is written to ``num_utts_path``. The ``MFA_UNKNOWN``
    speaker and speakers without a vector are skipped.
    """
    logger.info("Exporting current speaker ivectors...")
    with self.session() as session, tqdm.tqdm(
        total=self.num_speakers, disable=GLOBAL_CONFIG.quiet
    ) as pbar, mfa_open(self.num_utts_path, "w") as f:
        ivector_column = Speaker.xvector if self.use_xvector else Speaker.ivector
        speakers = (
            session.query(Speaker.id, ivector_column, sqlalchemy.func.count(Utterance.id))
            .join(Speaker.utterances)
            .filter(Speaker.name != "MFA_UNKNOWN")
            .group_by(Speaker.id)
            .order_by(Speaker.id)
        )
        input_proc = subprocess.Popen(
            [
                thirdparty_binary("copy-vector"),
                "--binary=true",
                "ark,t:-",
                f"ark:{self.speaker_ivector_path}",
            ],
            stdin=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            env=os.environ,
        )
        try:
            for s_id, ivector, utterance_count in speakers:
                if ivector is None:
                    continue
                ivector = " ".join([format(x, ".12g") for x in ivector])
                # Buffered writes are flushed when stdin is closed below
                input_proc.stdin.write(f"{s_id} [ {ivector} ]\n".encode("utf8"))
                pbar.update(1)
                f.write(f"{s_id} {utterance_count}\n")
        finally:
            # Close the pipe and reap the child even when the loop raises, so the subprocess
            # is never left running
            try:
                input_proc.stdin.close()
            finally:
                input_proc.wait()
def classify_iteration(self, iteration=None) -> None:
    """
    Run one PLDA classification pass and update utterance speakers

    Parameters
    ----------
    iteration: int, optional
        Index of the current iteration; when given, the score threshold and the minimum
        speaker size are taken from linear schedules over ``max_iterations`` steps
    """
    logger.info("Classifying utterances...")
    scheduled = iteration is not None

    low_count = None
    if scheduled and self.min_cluster_size:
        size_schedule = np.linspace(0, self.min_cluster_size, self.max_iterations)
        low_count = size_schedule[iteration]
        logger.debug(f"Minimum size: {low_count}")

    if scheduled:
        threshold_schedule = np.linspace(
            self.initial_plda_score_threshold,
            self.plda_score_threshold,
            self.max_iterations,
        )
        score_threshold = threshold_schedule[iteration]
    else:
        score_threshold = self.plda_score_threshold
    logger.debug(f"Score threshold: {score_threshold}")

    with self.session() as session, tqdm.tqdm(
        total=self.num_utterances, disable=GLOBAL_CONFIG.quiet
    ) as progress:
        unknown_speaker_id = (
            session.query(Speaker.id).filter(Speaker.name == "MFA_UNKNOWN").first()[0]
        )
        utterance_mapping = []
        self.classification_score = 0
        plda_transform_path = os.path.join(self.working_directory, "plda.pkl")
        with open(plda_transform_path, "rb") as plda_file:
            self.plda: PldaModel = pickle.load(plda_file)
        arguments = self.plda_classification_arguments()
        func = PldaClassificationFunction
        # Current assignment, used to skip utterances whose speaker does not change
        utt2spk = {}
        for utterance_id, current_speaker in session.query(
            Utterance.id, Utterance.speaker_id
        ):
            utt2spk[utterance_id] = current_speaker
        for utt_id, classified_speaker, score in run_kaldi_function(
            func, arguments, progress.update
        ):
            self.classification_score += score / self.num_utterances
            speaker_id = unknown_speaker_id if score < score_threshold else classified_speaker
            if speaker_id != utt2spk[utt_id]:
                utterance_mapping.append({"id": utt_id, "speaker_id": speaker_id})
        logger.debug(f"Updating {len(utterance_mapping)} utterances with new speakers")
        session.commit()
        # The indexes on utterance.speaker_id are dropped before the bulk update and
        # recreated afterwards
        session.execute(sqlalchemy.text("DROP INDEX IF EXISTS ix_utterance_speaker_id"))
        session.execute(sqlalchemy.text("DROP INDEX IF EXISTS utterance_position_index"))
        session.commit()
        bulk_update(session, Utterance, utterance_mapping)
        session.execute(
            sqlalchemy.text(
                "CREATE INDEX IF NOT EXISTS ix_utterance_speaker_id on utterance(speaker_id)"
            )
        )
        session.execute(
            sqlalchemy.text(
                'CREATE INDEX IF NOT EXISTS utterance_position_index on utterance(file_id, speaker_id, begin, "end", channel)'
            )
        )
        session.commit()

    if scheduled and iteration < self.max_iterations - 2:
        self.breakup_large_clusters()
    self.cleanup_empty_speakers(low_count)
def breakup_large_clusters(self):
    """
    Re-cluster the unknown speaker and every speaker with more than 500 utterances

    The PLDA vectors of each candidate speaker are clustered with OPTICS (cosine distance).
    Utterances labelled as noise (-1) move to the ``MFA_UNKNOWN`` speaker, and every other
    cluster becomes a new speaker named ``{source speaker id}_{index}``.

    NOTE(review): the handling of the one- and two-label results follows a reconstruction of
    the original nesting -- confirm against the original file.
    """
    # The counter is normally set by cluster_utterances_mfa, but this method can run before
    # that assignment
    if not hasattr(self, "_unknown_speaker_break_up_count"):
        self._unknown_speaker_break_up_count = 0
    with self.session() as session:
        unknown_speaker_id = (
            session.query(Speaker.id).filter(Speaker.name == "MFA_UNKNOWN").first()[0]
        )
        sq = (
            session.query(Speaker.id, sqlalchemy.func.count().label("utterance_count"))
            .join(Speaker.utterances)
            .filter(Speaker.id != unknown_speaker_id)
            .group_by(Speaker.id)
        )
        # The unknown speaker is always re-clustered
        above_threshold_speakers = [unknown_speaker_id]
        threshold = 500
        for s_id, utterance_count in sq:
            if threshold and utterance_count > threshold and s_id not in self.single_clusters:
                above_threshold_speakers.append(s_id)
        logger.info("Breaking up large speakers...")
        logger.debug(f"Unknown speaker is {unknown_speaker_id}")
        next_speaker_id = self.get_next_primary_key(Speaker)
        with tqdm.tqdm(
            total=len(above_threshold_speakers), disable=GLOBAL_CONFIG.quiet
        ) as pbar:
            utterance_mapping = []
            # Keyed by (source speaker, cluster label): cluster labels restart at 0 for every
            # source speaker, so the label alone would merge clusters of different speakers
            new_speakers = {}
            for s_id in above_threshold_speakers:
                logger.debug(f"Breaking up {s_id}")
                query = session.query(Utterance.id, Utterance.plda_vector).filter(
                    Utterance.plda_vector != None, Utterance.speaker_id == s_id  # noqa
                )
                pbar.update(1)
                ivectors = np.empty((query.count(), PLDA_DIMENSION))
                logger.debug(f"Had {ivectors.shape[0]} utterances.")
                if ivectors.shape[0] == 0:
                    continue
                utterance_ids = []
                for i, (u_id, ivector) in enumerate(query):
                    if self.stopped.stop_check():
                        # The remaining rows of the matrix are uninitialised, so clustering
                        # it would produce meaningless labels
                        logger.debug("Stopping cluster break up early.")
                        return
                    utterance_ids.append(u_id)
                    ivectors[i, :] = ivector
                if ivectors.shape[0] < self.min_cluster_size:
                    continue
                labels = cluster_matrix(
                    ivectors,
                    ClusterType.optics,
                    metric=DistanceMetric.cosine,
                    strict=False,
                    no_visuals=True,
                    working_directory=self.working_directory,
                    distance_threshold=0.25,
                )
                unique, counts = np.unique(labels, return_counts=True)
                num_clusters = unique.shape[0]
                counts = dict(zip(unique, counts))
                logger.debug(f"{num_clusters} clusters found: {counts}")
                if num_clusters == 1:
                    if s_id != unknown_speaker_id:
                        logger.debug(f"Deleting {s_id} due to no clusters found")
                        session.execute(
                            sqlalchemy.update(Utterance)
                            .filter(Utterance.speaker_id == s_id)
                            .values({Utterance.speaker_id: unknown_speaker_id})
                        )
                        session.flush()
                    continue
                if num_clusters == 2:
                    if s_id != unknown_speaker_id:
                        logger.debug(
                            f"Only found one cluster for {s_id} will skip in the future"
                        )
                        self.single_clusters.add(s_id)
                        continue
                for i, utt_id in enumerate(utterance_ids):
                    label = int(labels[i])
                    if label == -1:
                        speaker_id = unknown_speaker_id
                    else:
                        if s_id in self.single_clusters:
                            continue
                        key = (s_id, label)
                        if key not in new_speakers:
                            name_index = label
                            if s_id == unknown_speaker_id:
                                # The unknown speaker is split on every call, so a running
                                # counter is used to name the speakers split from it
                                name_index = self._unknown_speaker_break_up_count
                                self._unknown_speaker_break_up_count += 1
                            new_speakers[key] = {
                                "id": next_speaker_id,
                                "name": f"{s_id}_{name_index}",
                            }
                            next_speaker_id += 1
                        speaker_id = new_speakers[key]["id"]
                    utterance_mapping.append({"id": utt_id, "speaker_id": speaker_id})
        if new_speakers:
            session.bulk_insert_mappings(Speaker, list(new_speakers.values()))
        # Committed unconditionally so that the flushed reassignments above are kept even
        # when no new speaker was created
        session.commit()
        if utterance_mapping:
            bulk_update(session, Utterance, utterance_mapping)
            session.commit()
        logger.debug(f"Broke speakers into {len(new_speakers)} new speakers.")
def cleanup_empty_speakers(self, threshold=None):
    """
    Delete speakers without utterances and optionally fold small speakers into the unknown speaker

    Parameters
    ----------
    threshold: int or float, optional
        When truthy, speakers with fewer utterances than this are reassigned to the
        ``MFA_UNKNOWN`` speaker and deleted
    """
    with self.session() as session:
        session.execute(sqlalchemy.delete(SpeakerOrdering))
        session.flush()
        unknown_speaker_id = (
            session.query(Speaker.id).filter(Speaker.name == "MFA_UNKNOWN").first()[0]
        )
        speaker_counts = (
            session.query(Speaker.id, sqlalchemy.func.count().label("utterance_count"))
            .join(Speaker.utterances)
            .filter(Speaker.id != unknown_speaker_id)
            .group_by(Speaker.id)
        )
        # The unknown speaker is always kept, even when it has no utterances
        kept_speakers = [unknown_speaker_id]
        folded_speakers = []
        for s_id, utterance_count in speaker_counts:
            too_small = threshold and utterance_count < threshold
            if too_small:
                folded_speakers.append(s_id)
            else:
                kept_speakers.append(s_id)
        session.execute(
            sqlalchemy.update(Utterance)
            .where(Utterance.speaker_id.in_(folded_speakers))
            .values(speaker_id=unknown_speaker_id)
        )
        # Speakers that did not show up in the join have no utterances and are removed along
        # with the folded ones
        session.execute(sqlalchemy.delete(Speaker).where(~Speaker.id.in_(kept_speakers)))
        session.commit()
        self._num_speakers = session.query(Speaker).count()
    # ANALYZE is issued on its own connection in autocommit mode
    with self.db_engine.connect() as conn:
        conn.execution_options(isolation_level="AUTOCOMMIT")
        conn.execute(
            sqlalchemy.text(f"ANALYZE {Speaker.__tablename__}, {Utterance.__tablename__}")
        )
[docs]
def cluster_utterances_mfa(self) -> None:
    """
    Cluster utterances with a ivector or speechbrain model

    Initial speaker labels are generated by classification, then ``max_iterations`` rounds of
    PLDA training, vector refreshing and re-classification are run, logging the average score
    and the number of unclassified utterances after each round.
    """
    self.cluster = False
    self.setup()
    # Must be set before initialize_mfa_clustering, which breaks up large speakers and
    # reads this counter while doing so
    self._unknown_speaker_break_up_count = 0

    def count_unclassified() -> int:
        with self.session() as session:
            return (
                session.query(Utterance)
                .join(Utterance.speaker)
                .filter(Speaker.name == "MFA_UNKNOWN")
                .count()
            )

    with self.session() as session:
        # The placeholder speaker for unclassified utterances is required by the later steps
        if session.query(Speaker).filter(Speaker.name == "MFA_UNKNOWN").first() is None:
            session.add(Speaker(id=self.get_next_primary_key(Speaker), name="MFA_UNKNOWN"))
            session.commit()
    self.initialize_mfa_clustering()
    uncategorized_count = count_unclassified()
    if self.use_xvector:
        logger.info(f"Initial average cosine score {self.classification_score:.4f}")
    else:
        logger.info(f"Initial average PLDA score {self.classification_score:.4f}")
    logger.info(f"Number of speakers: {self.num_speakers}")
    logger.info(f"Unclassified utterances: {uncategorized_count}")
    for i in range(self.max_iterations):
        logger.info(f"Iteration {i}:")
        current_score = self.classification_score
        self._write_ivectors()
        self.compute_plda()
        self.refresh_plda_vectors()
        self.refresh_speaker_vectors()
        self.export_speaker_ivectors()
        self.classify_iteration(i)
        improvement = self.classification_score - current_score
        uncategorized_count = count_unclassified()
        logger.info(f"Average PLDA score {self.classification_score:.4f}")
        logger.info(f"Improvement: {improvement:.4f}")
        logger.info(f"Number of speakers: {self.num_speakers}")
        logger.info(f"Unclassified utterances: {uncategorized_count}")
    logger.debug(f"Found {self.num_speakers} clusters")
    if GLOBAL_CONFIG.current_profile.debug and self.num_utterances < 100000:
        self.visualize_current_clusters()
def visualize_current_clusters(self):
    """
    Plot the utterance vectors coloured by their current speaker

    PLDA vectors are used when any utterance has one; otherwise the xvectors (speechbrain) or
    ivectors are used. A warning is logged and nothing is plotted when no vectors exist.
    """
    with self.session() as session:
        dim = PLDA_DIMENSION
        # `!= None` builds a SQL "IS NOT NULL" clause; Python's `is not None` would be
        # evaluated immediately to True and filter nothing
        query = (
            session.query(Speaker.name, Utterance.plda_vector)
            .join(Utterance.speaker)
            .filter(Utterance.plda_vector != None)  # noqa
        )
        num_utterances = query.count()
        if num_utterances == 0:
            if self.use_xvector:
                column = Utterance.xvector
                dim = XVECTOR_DIMENSION
            else:
                column = Utterance.ivector
                dim = IVECTOR_DIMENSION
            query = (
                session.query(Speaker.name, column)
                .join(Utterance.speaker)
                .filter(column != None)  # noqa
            )
            num_utterances = query.count()
        if num_utterances == 0:
            logger.warning("No ivectors/xvectors to visualize")
            return
        ivectors = np.empty((num_utterances, dim))
        labels = []
        for s_name, ivector in query:
            ivectors[len(labels), :] = ivector
            labels.append(s_name)
        self.visualize_clusters(ivectors, labels)
[docs]
def cluster_utterances(self) -> None:
    """
    Cluster utterances with a ivector or speechbrain model

    When ``cluster_type`` is ``ClusterType.mfa`` the work is delegated to
    ``cluster_utterances_mfa``. Otherwise the stored utterance vectors are
    clustered with ``cluster_matrix``; every non-negative cluster label becomes
    a new ``Cluster <label>`` speaker and all utterances with a negative label
    are assigned to a single ``MFA_UNKNOWN`` speaker.

    While clustering runs, ``OMP_NUM_THREADS``, ``OPENBLAS_NUM_THREADS`` and
    ``MKL_NUM_THREADS`` are set to the profile's ``num_jobs``; they are always
    set to the profile's ``blas_num_threads`` afterwards, including when
    clustering is stopped early or raises.
    """
    if self.cluster_type is ClusterType.mfa:
        self.cluster_utterances_mfa()
        self.fix_speaker_ordering()
        if not self.evaluation_mode:
            self.cleanup_empty_speakers()
        self.refresh_speaker_vectors()
        if self.evaluation_mode:
            self.evaluate_clustering()
        return
    self.setup()

    thread_variables = ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS")
    for variable in thread_variables:
        os.environ[variable] = f"{GLOBAL_CONFIG.current_profile.num_jobs}"
    try:
        if self.metric is DistanceMetric.plda:
            plda_transform_path = os.path.join(self.working_directory, "plda.pkl")
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # this is only safe while the working directory is trusted.
            with open(plda_transform_path, "rb") as f:
                self.plda: PldaModel = pickle.load(f)
        if self.evaluation_mode and GLOBAL_CONFIG.current_profile.debug:
            self.calculate_eer()

        logger.info("Clustering utterances (this may take a while, please be patient)...")
        with self.session() as session:
            # Pick the vector column to cluster on and its dimensionality
            if self.use_pca:
                column, dimension = Utterance.plda_vector, PLDA_DIMENSION
            elif self.use_xvector:
                column, dimension = Utterance.xvector, XVECTOR_DIMENSION
            else:
                column, dimension = Utterance.ivector, IVECTOR_DIMENSION
            query = session.query(Utterance.id, column).filter(column != None)  # noqa
            ivectors = np.empty((query.count(), dimension))

            utterance_ids = []
            for i, (u_id, ivector) in enumerate(query):
                if self.stopped.stop_check():
                    break
                utterance_ids.append(u_id)
                ivectors[i, :] = ivector
            if self.stopped.stop_check():
                logger.debug("Stopping clustering early.")
                return

            kwargs = {
                "min_cluster_size": self.min_cluster_size,
                "distance_threshold": self.distance_threshold,
            }
            if self.cluster_type is ClusterType.agglomerative:
                kwargs["memory"] = MEMORY
                kwargs["linkage"] = self.linkage
                # A falsy expected speaker count is passed on as ``None``
                kwargs["n_clusters"] = self.expected_num_speakers or None
            elif self.cluster_type is ClusterType.spectral:
                kwargs["n_clusters"] = self.expected_num_speakers
            elif self.cluster_type is ClusterType.hdbscan:
                kwargs["memory"] = MEMORY
            elif self.cluster_type is ClusterType.optics:
                kwargs["memory"] = MEMORY
            elif self.cluster_type is ClusterType.kmeans:
                kwargs["n_clusters"] = self.expected_num_speakers
            labels = cluster_matrix(
                ivectors,
                self.cluster_type,
                metric=self.metric,
                plda=self.plda,
                working_directory=self.working_directory,
                **kwargs,
            )
            if self.stopped.stop_check():
                logger.debug("Stopping clustering early.")
                return
            if GLOBAL_CONFIG.current_profile.debug:
                self.visualize_clusters(ivectors, labels)

            # Group utterance ids by the cluster label they received
            utterance_clusters = collections.defaultdict(list)
            for u_id, label in zip(utterance_ids, labels):
                utterance_clusters[int(label)].append(u_id)

            utterance_mapping = []
            speaker_mapping = []
            next_speaker_id = self.get_next_primary_key(Speaker)
            unknown_speaker_id = None
            for cluster_id, cluster_utterance_ids in sorted(utterance_clusters.items()):
                if cluster_id < 0:
                    # All negative labels share one MFA_UNKNOWN speaker
                    if unknown_speaker_id is None:
                        speaker_mapping.append({"id": next_speaker_id, "name": "MFA_UNKNOWN"})
                        unknown_speaker_id = next_speaker_id
                        next_speaker_id += 1
                    speaker_id = unknown_speaker_id
                else:
                    speaker_mapping.append(
                        {"id": next_speaker_id, "name": f"Cluster {cluster_id}"}
                    )
                    speaker_id = next_speaker_id
                    next_speaker_id += 1
                for u_id in cluster_utterance_ids:
                    utterance_mapping.append({"id": u_id, "speaker_id": speaker_id})
            if self.stopped.stop_check():
                logger.debug("Stopping clustering early.")
                return
            if speaker_mapping:
                # Speakers must exist before utterances can reference them
                session.bulk_insert_mappings(Speaker, speaker_mapping)
                session.flush()
                session.commit()
                bulk_update(session, Utterance, utterance_mapping)
                session.flush()
                session.commit()

        if not self.evaluation_mode:
            self.clean_up_unknown_speaker()
        self.fix_speaker_ordering()
        if not self.evaluation_mode:
            self.cleanup_empty_speakers()
        self.refresh_speaker_vectors()
        if self.evaluation_mode:
            self.evaluate_clustering()
    finally:
        for variable in thread_variables:
            os.environ[variable] = f"{GLOBAL_CONFIG.current_profile.blas_num_threads}"
def clean_up_unknown_speaker(self):
    """
    Reassign utterances of the "MFA_UNKNOWN" speaker to new per-file speakers

    For every file with utterances assigned to the speaker named "MFA_UNKNOWN",
    a new speaker named after the file is inserted and those utterances are
    pointed at it. The unknown speaker's rows in the speaker ordering table are
    deleted. Nothing is changed when no such speaker exists.
    """
    with self.session() as session:
        unknown_speaker = session.query(Speaker).filter(Speaker.name == "MFA_UNKNOWN").first()
        first_new_id = self.get_next_primary_key(Speaker)
        if unknown_speaker is None:
            return

        # One new speaker per file that contains unknown-speaker utterances
        files_with_unknown = (
            session.query(File.id, File.name)
            .join(File.utterances)
            .filter(Utterance.speaker_id == unknown_speaker.id)
            .distinct()
        )
        new_speakers = {
            file_id: {"id": first_new_id + offset, "name": file_name}
            for offset, (file_id, file_name) in enumerate(files_with_unknown)
        }

        # Point each affected utterance at the new speaker of its file
        unknown_utterances = (
            session.query(Utterance.id, Utterance.file_id)
            .join(File.utterances)
            .filter(Utterance.speaker_id == unknown_speaker.id)
        )
        reassignments = [
            {"id": utterance_id, "speaker_id": new_speakers[file_id]["id"]}
            for utterance_id, file_id in unknown_utterances
        ]

        session.bulk_insert_mappings(Speaker, list(new_speakers.values()))
        session.flush()
        session.execute(
            sqlalchemy.delete(SpeakerOrdering).where(
                SpeakerOrdering.c.speaker_id == unknown_speaker.id
            )
        )
        session.commit()
        bulk_update(session, Utterance, reassignments)
        session.commit()
[docs]
def calculate_eer(self) -> typing.Tuple[float, float]:
    """
    Calculate Equal Error Rate (EER) and threshold for the diarization metric using the ground truth data.

    Match and mismatch scores are pooled over all jobs. The pooled mismatch
    scores are shuffled and truncated to the number of match scores before the
    EER is computed.

    Returns
    -------
    float
        EER, or 0.0 when speechbrain is unavailable or no scores were generated
    float
        Threshold of EER, or 0.0 when speechbrain is unavailable or no scores were generated
    """
    if not FOUND_SPEECHBRAIN:
        logger.info("No speechbrain found, skipping EER calculation.")
        return 0.0, 0.0
    logger.info("Calculating EER using ground truth speakers...")
    limit_per_speaker = 5
    limit_within_speaker = 30
    begin = time.time()
    with tqdm.tqdm(total=self.num_speakers, disable=GLOBAL_CONFIG.quiet) as pbar:
        arguments = [
            ComputeEerArguments(
                j.id,
                self.db_string,
                None,
                self.plda,
                self.metric,
                self.use_xvector,
                limit_within_speaker,
                limit_per_speaker,
            )
            for j in self.jobs
        ]
        match_scores = []
        mismatch_scores = []
        for matches, mismatches in run_kaldi_function(
            ComputeEerFunction, arguments, pbar.update
        ):
            match_scores.extend(matches)
            mismatch_scores.extend(mismatches)
    if not match_scores or not mismatch_scores:
        logger.warning("No match or mismatch scores were generated, skipping EER calculation.")
        return 0.0, 0.0
    # Shuffle the pooled scores (not a single job's batch) so that the
    # truncation below keeps a random subset rather than the first jobs' output.
    random.shuffle(mismatch_scores)
    mismatch_scores = mismatch_scores[: len(match_scores)]
    match_scores = np.array(match_scores)
    mismatch_scores = np.array(mismatch_scores)
    device = torch.device("cuda" if self.cuda else "cpu")
    eer, thresh = EER(
        torch.tensor(mismatch_scores, device=device),
        torch.tensor(match_scores, device=device),
    )
    logger.debug(
        f"Matching scores: {np.min(match_scores):.3f}-{np.max(match_scores):.3f} (mean = {match_scores.mean():.3f}, n = {match_scores.shape[0]})"
    )
    logger.debug(
        f"Mismatching scores: {np.min(mismatch_scores):.3f}-{np.max(mismatch_scores):.3f} (mean = {mismatch_scores.mean():.3f}, n = {mismatch_scores.shape[0]})"
    )
    logger.info(f"EER: {eer*100:.2f}%")
    logger.info(f"Threshold: {thresh:.4f}")
    logger.debug(f"Calculating EER took {time.time() - begin:.3f} seconds")
    return eer, thresh
[docs]
def load_embeddings(self) -> None:
    """
    Load embeddings from a speechbrain model

    Each embedding is stored as the utterance's ``xvector``. The ``plda_vector``
    is a copy of it when ``PLDA_DIMENSION`` equals ``XVECTOR_DIMENSION``, and a
    PCA projection to ``PLDA_DIMENSION`` otherwise; no ``plda_vector`` is stored
    when there are fewer embeddings than ``PLDA_DIMENSION``. Does nothing when
    ``has_xvectors()`` already reports loaded embeddings.
    """
    if self.has_xvectors():
        logger.info("Embeddings already loaded.")
        return
    logger.info("Loading SpeechBrain embeddings...")
    with tqdm.tqdm(
        total=self.num_utterances, disable=GLOBAL_CONFIG.quiet
    ) as pbar, self.session() as session:
        begin = time.time()
        job_arguments = [
            SpeechbrainArguments(j.id, self.db_string, None, self.cuda, self.cluster)
            for j in self.jobs
        ]
        rows_by_utterance = {}
        ordered_ids = []
        collected = []
        for u_id, emb in run_kaldi_function(
            SpeechbrainEmbeddingFunction, job_arguments, pbar.update
        ):
            ordered_ids.append(u_id)
            collected.append(emb)
            rows_by_utterance[u_id] = {"id": u_id, "xvector": emb}
        embedding_matrix = np.array(collected)

        if PLDA_DIMENSION == XVECTOR_DIMENSION:
            # Same dimensionality: the embedding doubles as the PLDA vector
            for row in rows_by_utterance.values():
                row["plda_vector"] = row["xvector"]
        elif embedding_matrix.shape[0] < PLDA_DIMENSION:
            logger.debug("Can't run PLDA due to too few features.")
        else:
            pca = decomposition.PCA(PLDA_DIMENSION)
            pca.fit(embedding_matrix)
            logger.debug(
                f"PCA explained variance: {np.sum(pca.explained_variance_ratio_)*100:.2f}%"
            )
            projected = pca.transform(embedding_matrix)
            for u_id, projected_row in zip(ordered_ids, projected):
                rows_by_utterance[u_id]["plda_vector"] = projected_row

        bulk_update(session, Utterance, list(rows_by_utterance.values()))
        session.flush()
        for index_statement in (
            "CREATE INDEX IF NOT EXISTS utterance_xvector_index ON utterance USING ivfflat (xvector vector_cosine_ops);",
            "CREATE INDEX IF NOT EXISTS utterance_plda_vector_index ON utterance USING ivfflat (plda_vector vector_cosine_ops);",
        ):
            session.execute(sqlalchemy.text(index_statement))
        session.query(Corpus).update({Corpus.xvectors_loaded: True})
        session.commit()
        logger.debug(f"Loading embeddings took {time.time() - begin:.3f} seconds")
def refresh_plda_vectors(self):
    """
    Recompute the stored PLDA vector of every utterance that has a vector

    The PLDA model is reloaded from ``plda_path``, applied to the xvector or
    ivector column (depending on ``use_xvector``), and the results are written
    to ``plda_vector``. The loaded model is then pickled to ``plda.pkl`` in the
    working directory.
    """
    logger.info("Refreshing PLDA vectors...")
    self.plda = PldaModel.load(self.plda_path)
    with self.session() as session, tqdm.tqdm(
        total=self.num_utterances, disable=GLOBAL_CONFIG.quiet
    ) as pbar:
        source_column = Utterance.xvector if self.use_xvector else Utterance.ivector
        rows = session.query(Utterance.id, source_column).filter(
            source_column != None  # noqa
        )
        utterance_ids = []
        raw_vectors = []
        for utt_id, vector in rows:
            pbar.update(1)
            utterance_ids.append(utt_id)
            raw_vectors.append(vector)
        processed = self.plda.process_ivectors(np.array(raw_vectors))
        update_mapping = [
            {"id": utt_id, "plda_vector": processed[index, :]}
            for index, utt_id in enumerate(utterance_ids)
        ]
        bulk_update(session, Utterance, update_mapping)
        session.commit()
    # Save the model that produced the vectors just written
    pickle_path = os.path.join(self.working_directory, "plda.pkl")
    with open(pickle_path, "wb") as f:
        pickle.dump(self.plda, f)
[docs]
def refresh_speaker_vectors(self) -> None:
    """
    Refresh speaker vectors following clustering or classification

    Each speaker's xvector or ivector (depending on ``use_xvector``) is set to
    the mean of its utterances' vectors. Utterances without a vector are
    ignored and speakers with no vectors are left untouched. When a PLDA model
    is loaded, the speaker's ``plda_vector`` is updated from the mean vector as
    well.
    """
    logger.info("Refreshing speaker vectors...")
    with self.session() as session, tqdm.tqdm(
        total=self.num_speakers, disable=GLOBAL_CONFIG.quiet
    ) as pbar:
        if self.use_xvector:
            ivector_column = Utterance.xvector
            key = "xvector"
        else:
            ivector_column = Utterance.ivector
            key = "ivector"
        update_mapping = {}
        speaker_ids = []
        ivectors = []
        for (s_id,) in session.query(Speaker.id):
            # Skip null vectors, a single None would break the mean below
            query = session.query(ivector_column).filter(
                Utterance.speaker_id == s_id, ivector_column != None  # noqa
            )
            s_ivectors = [u_ivector for (u_ivector,) in query]
            if not s_ivectors:
                continue
            mean_ivector = np.mean(np.array(s_ivectors), axis=0)
            speaker_ids.append(s_id)
            ivectors.append(mean_ivector)
            update_mapping[s_id] = {"id": s_id, key: mean_ivector}
            pbar.update(1)
        # Only run the PLDA transform when there is something to transform
        if self.plda is not None and speaker_ids:
            plda_vectors = self.plda.process_ivectors(np.array(ivectors))
            for i, speaker_id in enumerate(speaker_ids):
                update_mapping[speaker_id]["plda_vector"] = plda_vectors[i, :]
        bulk_update(session, Speaker, list(update_mapping.values()))
        session.commit()
[docs]
def compute_speaker_embeddings(self) -> None:
    """
    Generate per-speaker embeddings as the mean over their utterances

    Utterance embeddings are loaded first when ``has_xvectors()`` reports that
    they are missing. Speakers without any utterance xvector are skipped.
    """
    if not self.has_xvectors():
        self.load_embeddings()
    logger.info("Computing SpeechBrain speaker embeddings...")
    with tqdm.tqdm(
        total=self.num_speakers, disable=GLOBAL_CONFIG.quiet
    ) as pbar, self.session() as session:
        speaker_updates = []
        for (speaker_id,) in session.query(Speaker.id):
            xvector_query = session.query(Utterance.xvector).filter(
                Utterance.speaker_id == speaker_id, Utterance.xvector != None  # noqa
            )
            num_xvectors = xvector_query.count()
            if num_xvectors == 0:
                continue
            stacked = np.empty((num_xvectors, XVECTOR_DIMENSION))
            for row, (xvector,) in enumerate(xvector_query):
                stacked[row, :] = xvector
            speaker_updates.append({"id": speaker_id, "xvector": np.mean(stacked, axis=0)})
            pbar.update(1)
        bulk_update(session, Speaker, speaker_updates)
        session.commit()
[docs]
def export_files(self, output_directory: str) -> None:
    """
    Export files with their new speaker labels

    Any diagnostic files present in the working directory are copied to the
    output directory, the diarization parameters are written to
    ``parameters.yaml``, and one output file is written per corpus file that
    has utterances. When ``overwrite`` is not set and ``output_directory``
    already exists, output goes to ``speaker_classification`` in the working
    directory instead.

    Parameters
    ----------
    output_directory: str
        Output directory to save files
    """
    if os.path.exists(output_directory) and not self.overwrite:
        output_directory = os.path.join(self.working_directory, "speaker_classification")
    os.makedirs(output_directory, exist_ok=True)

    # Copy over whichever diagnostic files exist
    for diagnostic_name in (
        "diarization_evaluation_results.csv",
        "cluster_plot.png",
        "nearest_neighbors.png",
    ):
        source_path = os.path.join(self.working_directory, diagnostic_name)
        if not os.path.exists(source_path):
            continue
        shutil.copyfile(source_path, os.path.join(output_directory, diagnostic_name))

    parameters = {
        "ivector_extractor_path": str(self.ivector_extractor_path),
        "expected_num_speakers": self.expected_num_speakers,
        "cluster": self.cluster,
        "cuda": self.cuda,
        "metric": self.metric.name,
        "cluster_type": self.cluster_type.name,
        "distance_threshold": self.distance_threshold,
        "min_cluster_size": self.min_cluster_size,
        "linkage": self.linkage,
    }
    with mfa_open(os.path.join(output_directory, "parameters.yaml"), "w") as f:
        yaml.safe_dump(parameters, f)

    with self.session() as session:
        logger.info("Writing output files...")
        files = session.query(File).options(
            selectinload(File.utterances),
            selectinload(File.speakers),
            joinedload(File.sound_file, innerjoin=True).load_only(SoundFile.duration),
            joinedload(File.text_file, innerjoin=True).load_only(TextFile.file_type),
        )
        with tqdm.tqdm(total=self.num_files, disable=GLOBAL_CONFIG.quiet) as pbar:
            for file in files:
                if not file.utterances:
                    logger.debug(f"Could not find any utterances for {file.name}")
                    continue
                output_format = file.text_file.file_type
                output_path = construct_output_path(
                    file.name,
                    file.relative_path,
                    output_directory,
                    output_format=output_format,
                )
                if output_format != "lab":
                    tiers = file.construct_transcription_tiers(original_text=True)
                    export_textgrid(
                        tiers,
                        output_path,
                        file.duration,
                        self.export_frame_shift,
                        output_format,
                    )
                else:
                    # Only the first utterance's text is written for lab files
                    with mfa_open(output_path, "w") as f:
                        f.write(file.utterances[0].text)
                pbar.update(1)