"""
Speaker classification
======================
"""
from __future__ import annotations
import collections
import csv
import logging
import math
import os
import random
import shutil
import subprocess
import sys
import time
import typing
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as np
import sqlalchemy
import yaml
from _kalpy.ivector import Plda, ivector_normalize_length
from _kalpy.matrix import DoubleVector, FloatVector
from kalpy.utils import read_kaldi_object
from sklearn import metrics
from sqlalchemy.orm import joinedload, selectinload
from tqdm.rich import tqdm
from montreal_forced_aligner import config
from montreal_forced_aligner.abc import FileExporterMixin, TopLevelMfaWorker
from montreal_forced_aligner.config import (
IVECTOR_DIMENSION,
MEMORY,
PLDA_DIMENSION,
XVECTOR_DIMENSION,
)
from montreal_forced_aligner.corpus.features import ExportIvectorsArguments, ExportIvectorsFunction
from montreal_forced_aligner.corpus.ivector_corpus import IvectorCorpusMixin
from montreal_forced_aligner.data import (
ClusterType,
DistanceMetric,
ManifoldAlgorithm,
WorkflowType,
)
from montreal_forced_aligner.db import (
Corpus,
File,
SoundFile,
Speaker,
SpeakerOrdering,
TextFile,
Utterance,
bulk_update,
)
from montreal_forced_aligner.diarization.multiprocessing import (
ComputeEerArguments,
ComputeEerFunction,
PldaClassificationArguments,
PldaClassificationFunction,
SpeechbrainArguments,
SpeechbrainClassificationFunction,
SpeechbrainEmbeddingFunction,
cluster_matrix,
visualize_clusters,
)
from montreal_forced_aligner.exceptions import KaldiProcessingError
from montreal_forced_aligner.helper import load_configuration, mfa_open
from montreal_forced_aligner.models import IvectorExtractorModel
from montreal_forced_aligner.textgrid import construct_output_path, export_textgrid
from montreal_forced_aligner.utils import log_kaldi_errors, run_kaldi_function, thirdparty_binary
try:
import warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
torch_logger = logging.getLogger("speechbrain.utils.torch_audio_backend")
torch_logger.setLevel(logging.ERROR)
torch_logger = logging.getLogger("speechbrain.utils.train_logger")
torch_logger.setLevel(logging.ERROR)
import torch
try:
from speechbrain.pretrained import EncoderClassifier, SpeakerRecognition
except ImportError: # speechbrain 1.0
from speechbrain.inference.classifiers import EncoderClassifier
from speechbrain.inference.speaker import SpeakerRecognition
from speechbrain.utils.metric_stats import EER
FOUND_SPEECHBRAIN = True
except (ImportError, OSError):
FOUND_SPEECHBRAIN = False
EncoderClassifier = None
if TYPE_CHECKING:
from montreal_forced_aligner.abc import MetaDict
__all__ = ["SpeakerDiarizer", "FOUND_SPEECHBRAIN"]
logger = logging.getLogger("mfa")
[docs]
class SpeakerDiarizer(IvectorCorpusMixin, TopLevelMfaWorker, FileExporterMixin):
"""
Class for performing speaker classification, not currently very functional, but
is planned to be expanded in the future
Parameters
----------
ivector_extractor_path : str
Path to ivector extractor model, or "speechbrain"
expected_num_speakers: int, optional
Number of speakers in the corpus, if known
cluster: bool
Flag for whether speakers should be clustered instead of classified
evaluation_mode: bool
Flag for evaluating against existing speaker labels
cuda: bool
Flag for using CUDA for speechbrain models
metric: str or :class:`~montreal_forced_aligner.data.DistanceMetric`
One of "cosine", "plda", or "euclidean"
cluster_type: str or :class:`~montreal_forced_aligner.data.ClusterType`
Clustering algorithm
relative_distance_threshold: float
Threshold to use clustering based on distance
"""
def __init__(
self,
ivector_extractor_path: typing.Union[str, Path] = "speechbrain",
expected_num_speakers: int = 0,
cluster: bool = True,
evaluation_mode: bool = False,
cuda: bool = False,
use_pca: bool = True,
metric: typing.Union[str, DistanceMetric] = "cosine",
cluster_type: typing.Union[str, ClusterType] = "hdbscan",
manifold_algorithm: typing.Union[str, ManifoldAlgorithm] = "tsne",
distance_threshold: float = None,
score_threshold: float = None,
min_cluster_size: int = 60,
max_iterations: int = 10,
linkage: str = "average",
**kwargs,
):
self.sliding_cmvn = True
self.use_xvector = False
self.ivector_extractor = None
self.ivector_extractor_path = ivector_extractor_path
if ivector_extractor_path == "speechbrain":
if not FOUND_SPEECHBRAIN:
logger.error(
"Could not import speechbrain, please ensure it is installed via `pip install speechbrain`"
)
sys.exit(1)
self.use_xvector = True
else:
self.ivector_extractor = IvectorExtractorModel(ivector_extractor_path)
kwargs.update(self.ivector_extractor.parameters)
super().__init__(**kwargs)
self.expected_num_speakers = expected_num_speakers
self.cluster = cluster
self.metric = DistanceMetric[metric]
self.cuda = cuda
self.cluster_type = ClusterType[cluster_type]
self.manifold_algorithm = ManifoldAlgorithm[manifold_algorithm]
self.distance_threshold = distance_threshold
self.score_threshold = score_threshold
if self.distance_threshold is None:
if self.use_xvector:
self.distance_threshold = 0.25
self.evaluation_mode = evaluation_mode
self.min_cluster_size = min_cluster_size
self.linkage = linkage
self.use_pca = use_pca
self.max_iterations = max_iterations
self.current_labels = []
self.classification_score = None
self.initial_plda_score_threshold = 0
self.plda_score_threshold = 10
self.initial_sb_score_threshold = 0.75
self.ground_truth_utt2spk = {}
self.ground_truth_speakers = {}
self.single_clusters = set()
self._unknown_speaker_break_up_count = 0
[docs]
@classmethod
def parse_parameters(
cls,
config_path: Optional[Path] = None,
args: Optional[Dict[str, Any]] = None,
unknown_args: Optional[List[str]] = None,
) -> MetaDict:
"""
Parse parameters for speaker classification from a config path or command-line arguments
Parameters
----------
config_path: :class:`~pathlib.Path`
Config path
args: dict[str, Any]
Parsed arguments
unknown_args: list[str]
Optional list of arguments that were not parsed
Returns
-------
dict[str, Any]
Configuration parameters
"""
global_params = {}
if config_path and os.path.exists(config_path):
data = load_configuration(config_path)
for k, v in data.items():
if k == "features":
if "type" in v:
v["feature_type"] = v["type"]
del v["type"]
global_params.update(v)
else:
if v is None and k in cls.nullable_fields:
v = []
global_params[k] = v
global_params.update(cls.parse_args(args, unknown_args))
return global_params
# noinspection PyTypeChecker
[docs]
def setup(self) -> None:
"""
Sets up the corpus and speaker classifier
Raises
------
:class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
If there were any errors in running Kaldi binaries
"""
if self.initialized:
return
super().setup()
self.create_new_current_workflow(WorkflowType.speaker_diarization)
wf = self.current_workflow
if wf.done:
logger.info("Diarization already done, skipping initialization.")
return
log_dir = self.working_directory.joinpath("log")
os.makedirs(log_dir, exist_ok=True)
try:
if self.ivector_extractor is None: # Download models if needed
_ = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir=os.path.join(
config.TEMPORARY_DIRECTORY,
"models",
"EncoderClassifier",
),
)
_ = SpeakerRecognition.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir=os.path.join(
config.TEMPORARY_DIRECTORY,
"models",
"SpeakerRecognition",
),
)
self.initialize_database()
self._load_corpus()
self.initialize_jobs()
self.load_embeddings()
if self.cluster:
self.compute_speaker_embeddings()
else:
if not self.has_ivectors():
if self.ivector_extractor.meta["version"] < "2.1":
logger.warning(
"The ivector extractor was trained in an earlier version of MFA. "
"There may be incompatibilities in feature generation that cause errors. "
"Please download the latest version of the model via `mfa model download`, "
"use a different ivector extractor, or use version 2.0.6 of MFA."
)
self.ivector_extractor.export_model(self.working_directory)
self.load_corpus()
self.extract_ivectors()
self.compute_speaker_ivectors()
if self.evaluation_mode:
self.ground_truth_utt2spk = {}
with self.session() as session:
query = session.query(Utterance.id, Utterance.speaker_id, Speaker.name).join(
Utterance.speaker
)
for u_id, s_id, name in query:
self.ground_truth_utt2spk[u_id] = s_id
self.ground_truth_speakers[s_id] = name
except Exception as e:
if isinstance(e, KaldiProcessingError):
log_kaldi_errors(e.error_logs)
e.update_log_file()
raise
self.initialized = True
[docs]
def plda_classification_arguments(self) -> List[PldaClassificationArguments]:
"""
Generate Job arguments for :class:`~montreal_forced_aligner.diarization.multiprocessing.PldaClassificationFunction`
Returns
-------
list[:class:`~montreal_forced_aligner.diarization.multiprocessing.PldaClassificationArguments`]
Arguments for processing
"""
return [
PldaClassificationArguments(
j.id,
getattr(self, "session" if config.USE_THREADING else "db_string", ""),
self.working_log_directory.joinpath(f"plda_classification.{j.id}.log"),
self.plda_path,
self.speaker_ivector_path,
self.num_utts_path,
self.use_xvector,
)
for j in self.jobs
]
[docs]
def classify_speakers(self):
"""Classify speakers based on ivector or speechbrain model"""
self.setup()
logger.info("Classifying utterances...")
with self.session() as session, mfa_open(
self.working_directory.joinpath("speaker_classification_results.csv"), "w"
) as f:
writer = csv.DictWriter(f, ["utt_id", "file", "begin", "end", "speaker", "score"])
writer.writeheader()
file_names = {
k: v for k, v in session.query(Utterance.id, File.name).join(Utterance.file)
}
utterance_times = {
k: (b, e)
for k, b, e in session.query(Utterance.id, Utterance.begin, Utterance.end)
}
utterance_mapping = []
next_speaker_id = self.get_next_primary_key(Speaker)
speaker_mapping = {}
existing_speakers = {
name: s_id for s_id, name in session.query(Speaker.id, Speaker.name)
}
self.classification_score = 0
if session.query(Speaker).filter(Speaker.name == "MFA_UNKNOWN").first() is None:
session.add(Speaker(id=next_speaker_id, name="MFA_UNKNOWN"))
session.commit()
next_speaker_id += 1
unknown_speaker_id = (
session.query(Speaker).filter(Speaker.name == "MFA_UNKNOWN").first().id
)
if self.use_xvector:
arguments = [
SpeechbrainArguments(
j.id,
getattr(self, "session" if config.USE_THREADING else "db_string", ""),
None,
self.cuda,
self.cluster,
)
for j in self.jobs
]
func = SpeechbrainClassificationFunction
else:
self.plda = read_kaldi_object(Plda, self.plda_path)
arguments = self.plda_classification_arguments()
func = PldaClassificationFunction
for utt_id, classified_speaker, score in run_kaldi_function(
func, arguments, total_count=self.num_utterances
):
classified_speaker = str(classified_speaker)
self.classification_score += score / self.num_utterances
if self.score_threshold is not None and score < self.score_threshold:
speaker_id = unknown_speaker_id
elif classified_speaker in existing_speakers:
speaker_id = existing_speakers[classified_speaker]
else:
if classified_speaker not in speaker_mapping:
speaker_mapping[classified_speaker] = {
"id": next_speaker_id,
"name": classified_speaker,
}
next_speaker_id += 1
speaker_id = speaker_mapping[classified_speaker]["id"]
utterance_mapping.append({"id": utt_id, "speaker_id": speaker_id})
line = {
"utt_id": utt_id,
"file": file_names[utt_id],
"begin": utterance_times[utt_id][0],
"end": utterance_times[utt_id][1],
"speaker": classified_speaker,
"score": score,
}
writer.writerow(line)
if self.stopped.is_set():
logger.debug("Stopping clustering early.")
return
if speaker_mapping:
session.bulk_insert_mappings(Speaker, list(speaker_mapping.values()))
session.flush()
session.commit()
bulk_update(session, Utterance, utterance_mapping)
session.commit()
if not self.evaluation_mode:
self.clean_up_unknown_speaker()
self.fix_speaker_ordering()
if not self.evaluation_mode:
self.cleanup_empty_speakers()
self.refresh_speaker_vectors()
if self.evaluation_mode:
self.evaluate_classification()
def map_speakers_to_ground_truth(self):
with self.session() as session:
utterances = session.query(Utterance.id, Utterance.speaker_id)
labels = []
utterance_ids = []
for utt_id, s_id in utterances:
utterance_ids.append(utt_id)
labels.append(s_id)
ground_truth = np.array([self.ground_truth_utt2spk[x] for x in utterance_ids])
cluster_labels = np.unique(labels)
ground_truth_labels = np.unique(ground_truth)
cm = np.zeros((cluster_labels.shape[0], ground_truth_labels.shape[0]), dtype="int16")
for y_pred, y in zip(labels, ground_truth):
if y_pred < 0:
continue
cm[np.where(cluster_labels == y_pred), np.where(ground_truth_labels == y)] += 1
cm_argmax = cm.argmax(axis=1)
label_to_ground_truth_mapping = {}
for i in range(cluster_labels.shape[0]):
label_to_ground_truth_mapping[int(cluster_labels[i])] = int(
ground_truth_labels[cm_argmax[i]]
)
return label_to_ground_truth_mapping
[docs]
def evaluate_clustering(self) -> None:
"""Compute clustering metric scores and output clustering evaluation results"""
label_to_ground_truth_mapping = self.map_speakers_to_ground_truth()
with self.session() as session, mfa_open(
self.working_directory.joinpath("diarization_evaluation_results.csv"), "w"
) as f:
writer = csv.DictWriter(
f,
fieldnames=[
"file",
"begin",
"end",
"text",
"predicted_speaker",
"ground_truth_speaker",
],
)
writer.writeheader()
predicted_utt2spk = {}
query = session.query(
Utterance.id,
File.name,
Utterance.begin,
Utterance.end,
Utterance.text,
Utterance.speaker_id,
).join(Utterance.file)
for u_id, file_name, begin, end, text, s_id in query:
s_id = label_to_ground_truth_mapping[s_id]
predicted_utt2spk[u_id] = s_id
writer.writerow(
{
"file": file_name,
"begin": begin,
"end": end,
"text": text,
"predicted_speaker": self.ground_truth_speakers[s_id],
"ground_truth_speaker": self.ground_truth_speakers[
self.ground_truth_utt2spk[u_id]
],
}
)
ground_truth_labels = np.array([v for v in self.ground_truth_utt2spk.values()])
predicted_labels = np.array(
[predicted_utt2spk[k] for k in self.ground_truth_utt2spk.keys()]
)
rand_score = metrics.adjusted_rand_score(ground_truth_labels, predicted_labels)
ami_score = metrics.adjusted_mutual_info_score(ground_truth_labels, predicted_labels)
nmi_score = metrics.normalized_mutual_info_score(ground_truth_labels, predicted_labels)
homogeneity_score = metrics.homogeneity_score(ground_truth_labels, predicted_labels)
completeness_score = metrics.completeness_score(ground_truth_labels, predicted_labels)
v_measure_score = metrics.v_measure_score(ground_truth_labels, predicted_labels)
fm_score = metrics.fowlkes_mallows_score(ground_truth_labels, predicted_labels)
logger.info(f"Adjusted Rand index score (0-1, higher is better): {rand_score:.4f}")
logger.info(f"Normalized Mutual Information score (perfect=1.0): {nmi_score:.4f}")
logger.info(f"Adjusted Mutual Information score (perfect=1.0): {ami_score:.4f}")
logger.info(f"Homogeneity score (0-1, higher is better): {homogeneity_score:.4f}")
logger.info(f"Completeness score (0-1, higher is better): {completeness_score:.4f}")
logger.info(f"V measure score (0-1, higher is better): {v_measure_score:.4f}")
logger.info(f"Fowlkes-Mallows score (0-1, higher is better): {fm_score:.4f}")
[docs]
def evaluate_classification(self) -> None:
"""Evaluate and output classification accuracy"""
label_to_ground_truth_mapping = self.map_speakers_to_ground_truth()
with self.session() as session, mfa_open(
self.working_directory.joinpath("diarization_evaluation_results.csv"), "w"
) as f:
writer = csv.DictWriter(
f,
fieldnames=[
"file",
"begin",
"end",
"text",
"predicted_speaker",
"ground_truth_speaker",
],
)
writer.writeheader()
predicted_utt2spk = {}
query = session.query(
Utterance.id,
File.name,
Utterance.begin,
Utterance.end,
Utterance.text,
Utterance.speaker_id,
).join(Utterance.file)
for u_id, file_name, begin, end, text, s_id in query:
s_id = label_to_ground_truth_mapping[s_id]
predicted_utt2spk[u_id] = s_id
writer.writerow(
{
"file": file_name,
"begin": begin,
"end": end,
"text": text,
"predicted_speaker": self.ground_truth_speakers[s_id],
"ground_truth_speaker": self.ground_truth_speakers[
self.ground_truth_utt2spk[u_id]
],
}
)
ground_truth_labels = np.array([v for v in self.ground_truth_utt2spk.values()])
predicted_labels = np.array(
[
predicted_utt2spk[k] if k in predicted_utt2spk else -1
for k in self.ground_truth_utt2spk.keys()
]
)
precision_score = metrics.precision_score(
ground_truth_labels, predicted_labels, average="weighted"
)
recall_score = metrics.recall_score(
ground_truth_labels, predicted_labels, average="weighted"
)
f1_score = metrics.f1_score(ground_truth_labels, predicted_labels, average="weighted")
logger.info(f"Precision (0-1): {precision_score:.4f}")
logger.info(f"Recall (0-1): {recall_score:.4f}")
logger.info(f"F1 (0-1): {f1_score:.4f}")
@property
def num_utts_path(self) -> str:
"""Path to archive containing number of per training speaker"""
return self.working_directory.joinpath("num_utts.ark")
@property
def speaker_ivector_path(self) -> str:
"""Path to archive containing training speaker ivectors"""
return self.working_directory.joinpath("speaker_ivectors.ark")
def visualize_clusters(self, ivectors, cluster_labels=None):
import seaborn as sns
from matplotlib import pyplot as plt
sns.set()
metric = self.metric
if metric is DistanceMetric.plda:
metric = DistanceMetric.cosine
points = visualize_clusters(ivectors, self.manifold_algorithm, metric, 10, self.plda)
fig = plt.figure(1)
ax = fig.add_subplot(111)
if cluster_labels is not None:
unique_labels = np.unique(cluster_labels)
num_unique_labels = unique_labels.shape[0]
has_noise = 0 in set(unique_labels)
if has_noise:
num_unique_labels -= 1
cm = sns.color_palette("tab20", num_unique_labels)
for cluster in unique_labels:
if cluster == -1:
color = "k"
name = "Noise"
alpha = 0.75
else:
name = cluster
if not isinstance(name, str):
name = f"Cluster {name}"
cluster_id = cluster
else:
cluster_id = np.where(unique_labels == cluster)[0][0]
if has_noise:
color = cm[cluster_id - 1]
else:
color = cm[cluster_id]
alpha = 1.0
idx = np.where(cluster_labels == cluster)
ax.scatter(points[idx, 0], points[idx, 1], color=color, label=name, alpha=alpha)
else:
ax.scatter(points[:, 0], points[:, 1])
handles, labels = ax.get_legend_handles_labels()
fig.subplots_adjust(bottom=0.3, wspace=0.33)
plt.axis("off")
lgd = ax.legend(
handles,
labels,
loc="upper center",
bbox_to_anchor=(0.5, -0.1),
fancybox=True,
shadow=True,
ncol=5,
)
plot_path = self.working_directory.joinpath("cluster_plot.png")
plt.savefig(plot_path, bbox_extra_artists=(lgd,), bbox_inches="tight", transparent=True)
if config.VERBOSE:
plt.show(block=False)
plt.pause(10)
logger.debug(f"Closing cluster plot, it has been saved to {plot_path}.")
plt.close()
def export_xvectors(self):
logger.info("Exporting SpeechBrain embeddings...")
os.makedirs(self.split_directory, exist_ok=True)
arguments = [
ExportIvectorsArguments(
j.id,
getattr(self, "session" if config.USE_THREADING else "db_string", ""),
j.construct_path(self.working_log_directory, "export_ivectors", "log"),
self.use_xvector,
)
for j in self.jobs
]
utterance_mapping = []
for utt_id, ark_path in run_kaldi_function(
ExportIvectorsFunction, arguments, total_count=self.num_utterances
):
utterance_mapping.append({"id": utt_id, "ivector_ark": ark_path})
with self.session() as session:
bulk_update(session, Utterance, utterance_mapping)
session.commit()
self._write_ivectors()
def fix_speaker_ordering(self):
with self.session() as session:
query = (
session.query(Speaker.id, File.id)
.join(Utterance.speaker)
.join(Utterance.file)
.distinct()
)
speaker_ordering_mapping = []
for s_id, f_id in query:
speaker_ordering_mapping.append({"speaker_id": s_id, "file_id": f_id, "index": 10})
session.execute(sqlalchemy.delete(SpeakerOrdering))
session.flush()
session.execute(
sqlalchemy.dialects.postgresql.insert(SpeakerOrdering)
.values(speaker_ordering_mapping)
.on_conflict_do_nothing()
)
session.commit()
def initialize_mfa_clustering(self):
self._unknown_speaker_break_up_count = 0
with self.session() as session:
next_speaker_id = self.get_next_primary_key(Speaker)
speaker_mapping = {}
existing_speakers = {
name: s_id for s_id, name in session.query(Speaker.id, Speaker.name)
}
utterance_mapping = []
self.classification_score = 0
unk_count = 0
if self.use_xvector:
arguments = [
SpeechbrainArguments(
j.id,
getattr(self, "session" if config.USE_THREADING else "db_string", ""),
None,
self.cuda,
self.cluster,
)
for j in self.jobs
]
func = SpeechbrainClassificationFunction
score_threshold = self.initial_sb_score_threshold
self.export_xvectors()
else:
self.plda = read_kaldi_object(Plda, self.plda_path)
arguments = self.plda_classification_arguments()
func = PldaClassificationFunction
score_threshold = self.initial_plda_score_threshold
logger.info("Generating initial speaker labels...")
utt2spk = {k: v for k, v in session.query(Utterance.id, Utterance.speaker_id)}
for utt_id, classified_speaker, score in run_kaldi_function(
func, arguments, total_count=self.num_utterances
):
classified_speaker = str(classified_speaker)
self.classification_score += score / self.num_utterances
if score < score_threshold:
unk_count += 1
utterance_mapping.append(
{"id": utt_id, "speaker_id": existing_speakers["MFA_UNKNOWN"]}
)
continue
if classified_speaker in existing_speakers:
speaker_id = existing_speakers[classified_speaker]
else:
if classified_speaker not in speaker_mapping:
speaker_mapping[classified_speaker] = {
"id": next_speaker_id,
"name": classified_speaker,
}
next_speaker_id += 1
speaker_id = speaker_mapping[classified_speaker]["id"]
if speaker_id == utt2spk[utt_id]:
continue
utterance_mapping.append({"id": utt_id, "speaker_id": speaker_id})
if speaker_mapping:
session.bulk_insert_mappings(Speaker, list(speaker_mapping.values()))
session.flush()
session.commit()
session.execute(sqlalchemy.text("DROP INDEX IF EXISTS ix_utterance_speaker_id"))
session.execute(sqlalchemy.text("DROP INDEX IF EXISTS utterance_position_index"))
session.commit()
bulk_update(session, Utterance, utterance_mapping)
session.execute(
sqlalchemy.text(
"CREATE INDEX IF NOT EXISTS ix_utterance_speaker_id on utterance(speaker_id)"
)
)
session.execute(
sqlalchemy.text(
'CREATE INDEX IF NOT EXISTS utterance_position_index on utterance(file_id, speaker_id, begin, "end", channel)'
)
)
session.commit()
self.breakup_large_clusters()
self.cleanup_empty_speakers()
def export_speaker_ivectors(self):
logger.info("Exporting current speaker ivectors...")
with self.session() as session, tqdm(
total=self.num_speakers, disable=config.QUIET
) as pbar, mfa_open(self.num_utts_path, "w") as f:
if self.use_xvector:
ivector_column = Speaker.xvector
else:
ivector_column = Speaker.ivector
speakers = (
session.query(Speaker.id, ivector_column, sqlalchemy.func.count(Utterance.id))
.join(Speaker.utterances)
.filter(Speaker.name != "MFA_UNKNOWN")
.group_by(Speaker.id)
.order_by(Speaker.id)
)
input_proc = subprocess.Popen(
[
thirdparty_binary("copy-vector"),
"--binary=true",
"ark,t:-",
f"ark:{self.speaker_ivector_path}",
],
stdin=subprocess.PIPE,
stderr=subprocess.DEVNULL,
env=os.environ,
)
for s_id, ivector, utterance_count in speakers:
if ivector is None:
continue
ivector = " ".join([format(x, ".12g") for x in ivector])
in_line = f"{s_id} [ {ivector} ]\n".encode("utf8")
input_proc.stdin.write(in_line)
input_proc.stdin.flush()
pbar.update(1)
f.write(f"{s_id} {utterance_count}\n")
input_proc.stdin.close()
input_proc.wait()
def classify_iteration(self, iteration=None) -> None:
logger.info("Classifying utterances...")
low_count = None
if iteration is not None and self.min_cluster_size:
low_count = np.linspace(0, self.min_cluster_size, self.max_iterations)[iteration]
logger.debug(f"Minimum size: {low_count}")
score_threshold = self.plda_score_threshold
if iteration is not None:
score_threshold = np.linspace(
self.initial_plda_score_threshold,
self.plda_score_threshold,
self.max_iterations,
)[iteration]
logger.debug(f"Score threshold: {score_threshold}")
with self.session() as session:
unknown_speaker_id = (
session.query(Speaker.id).filter(Speaker.name == "MFA_UNKNOWN").first()[0]
)
utterance_mapping = []
self.classification_score = 0
self.plda = read_kaldi_object(Plda, self.plda_path)
arguments = self.plda_classification_arguments()
func = PldaClassificationFunction
utt2spk = {k: v for k, v in session.query(Utterance.id, Utterance.speaker_id)}
for utt_id, classified_speaker, score in run_kaldi_function(
func, arguments, total_count=self.num_utterances
):
self.classification_score += score / self.num_utterances
if score < score_threshold:
speaker_id = unknown_speaker_id
else:
speaker_id = classified_speaker
if speaker_id == utt2spk[utt_id]:
continue
utterance_mapping.append({"id": utt_id, "speaker_id": speaker_id})
logger.debug(f"Updating {len(utterance_mapping)} utterances with new speakers")
session.commit()
session.execute(sqlalchemy.text("DROP INDEX IF EXISTS ix_utterance_speaker_id"))
session.execute(sqlalchemy.text("DROP INDEX IF EXISTS utterance_position_index"))
session.commit()
bulk_update(session, Utterance, utterance_mapping)
session.execute(
sqlalchemy.text(
"CREATE INDEX IF NOT EXISTS ix_utterance_speaker_id on utterance(speaker_id)"
)
)
session.execute(
sqlalchemy.text(
'CREATE INDEX IF NOT EXISTS utterance_position_index on utterance(file_id, speaker_id, begin, "end", channel)'
)
)
session.commit()
if iteration is not None and iteration < self.max_iterations - 2:
self.breakup_large_clusters()
self.cleanup_empty_speakers(low_count)
def breakup_large_clusters(self):
with self.session() as session:
unknown_speaker_id = (
session.query(Speaker.id).filter(Speaker.name == "MFA_UNKNOWN").first()[0]
)
sq = (
session.query(Speaker.id, sqlalchemy.func.count().label("utterance_count"))
.join(Speaker.utterances)
.filter(Speaker.id != unknown_speaker_id)
.group_by(Speaker.id)
)
above_threshold_speakers = [unknown_speaker_id]
threshold = 500
for s_id, utterance_count in sq:
if threshold and utterance_count > threshold and s_id not in self.single_clusters:
above_threshold_speakers.append(s_id)
logger.info("Breaking up large speakers...")
logger.debug(f"Unknown speaker is {unknown_speaker_id}")
next_speaker_id = self.get_next_primary_key(Speaker)
with tqdm(total=len(above_threshold_speakers), disable=config.QUIET) as pbar:
utterance_mapping = []
new_speakers = {}
for s_id in above_threshold_speakers:
logger.debug(f"Breaking up {s_id}")
query = session.query(Utterance.id, Utterance.plda_vector).filter(
Utterance.plda_vector != None, Utterance.speaker_id == s_id # noqa
)
pbar.update(1)
ivectors = np.empty((query.count(), PLDA_DIMENSION))
logger.debug(f"Had {ivectors.shape[0]} utterances.")
if ivectors.shape[0] == 0:
continue
utterance_ids = []
for i, (u_id, ivector) in enumerate(query):
if self.stopped.is_set():
break
utterance_ids.append(u_id)
ivectors[i, :] = ivector
if ivectors.shape[0] < self.min_cluster_size:
continue
labels = cluster_matrix(
ivectors,
ClusterType.optics,
metric=DistanceMetric.cosine,
strict=False,
no_visuals=True,
working_directory=self.working_directory,
distance_threshold=0.25,
)
unique, counts = np.unique(labels, return_counts=True)
num_clusters = unique.shape[0]
counts = dict(zip(unique, counts))
logger.debug(f"{num_clusters} clusters found: {counts}")
if num_clusters == 1:
if s_id != unknown_speaker_id:
logger.debug(f"Deleting {s_id} due to no clusters found")
session.execute(
sqlalchemy.update(Utterance)
.filter(Utterance.speaker_id == s_id)
.values({Utterance.speaker_id: unknown_speaker_id})
)
session.flush()
continue
if num_clusters == 2:
if s_id != unknown_speaker_id:
logger.debug(
f"Only found one cluster for {s_id} will skip in the future"
)
self.single_clusters.add(s_id)
continue
for i, utt_id in enumerate(utterance_ids):
label = labels[i]
if label == -1:
speaker_id = unknown_speaker_id
else:
if s_id in self.single_clusters:
continue
if label not in new_speakers:
if s_id == unknown_speaker_id:
label = self._unknown_speaker_break_up_count
self._unknown_speaker_break_up_count += 1
new_speakers[label] = {
"id": next_speaker_id,
"name": f"{s_id}_{label}",
}
next_speaker_id += 1
speaker_id = new_speakers[label]["id"]
utterance_mapping.append({"id": utt_id, "speaker_id": speaker_id})
if new_speakers:
session.bulk_insert_mappings(Speaker, list(new_speakers.values()))
session.commit()
if utterance_mapping:
bulk_update(session, Utterance, utterance_mapping)
session.commit()
logger.debug(f"Broke speakers into {len(new_speakers)} new speakers.")
def cleanup_empty_speakers(self, threshold=None):
begin = time.time()
with self.session() as session:
session.execute(sqlalchemy.delete(SpeakerOrdering))
session.flush()
unknown_speaker_id = (
session.query(Speaker.id).filter(Speaker.name == "MFA_UNKNOWN").first()
)
if unknown_speaker_id is None:
unknown_speaker_id = -1
else:
unknown_speaker_id = unknown_speaker_id[0]
non_empty_speakers = [unknown_speaker_id]
sq = (
session.query(
Speaker.id, sqlalchemy.func.count(Utterance.id).label("utterance_count")
)
.outerjoin(Speaker.utterances)
.filter(Speaker.id != unknown_speaker_id)
.group_by(Speaker.id)
)
below_threshold_speakers = []
for s_id, utterance_count in sq:
if threshold and utterance_count < threshold:
below_threshold_speakers.append(s_id)
continue
elif utterance_count == 0:
below_threshold_speakers.append(s_id)
continue
non_empty_speakers.append(s_id)
session.execute(
sqlalchemy.update(Utterance)
.where(Utterance.speaker_id.in_(below_threshold_speakers))
.values(speaker_id=unknown_speaker_id)
)
session.execute(sqlalchemy.delete(Speaker).where(~Speaker.id.in_(non_empty_speakers)))
session.commit()
self._num_speakers = session.query(Speaker).count()
conn = self.db_engine.connect()
try:
conn.execution_options(isolation_level="AUTOCOMMIT")
conn.execute(
sqlalchemy.text(
f"VACUUM ANALYZE {Speaker.__tablename__}, {Utterance.__tablename__}"
)
)
finally:
conn.close()
logger.debug(
f"Cleaning up {len(below_threshold_speakers)} empty speakers took {time.time() - begin:.3f} seconds."
)
[docs]
def cluster_utterances_mfa(self) -> None:
"""
Cluster utterances with a ivector or speechbrain model
"""
self.cluster = False
self.setup()
with self.session() as session:
if session.query(Speaker).filter(Speaker.name == "MFA_UNKNOWN").first() is None:
session.add(Speaker(id=self.get_next_primary_key(Speaker), name="MFA_UNKNOWN"))
session.commit()
self.initialize_mfa_clustering()
with self.session() as session:
uncategorized_count = (
session.query(Utterance)
.join(Utterance.speaker)
.filter(Speaker.name == "MFA_UNKNOWN")
.count()
)
if self.use_xvector:
logger.info(f"Initial average cosine score {self.classification_score:.4f}")
else:
logger.info(f"Initial average PLDA score {self.classification_score:.4f}")
logger.info(f"Number of speakers: {self.num_speakers}")
logger.info(f"Unclassified utterances: {uncategorized_count}")
self._unknown_speaker_break_up_count = 0
for i in range(self.max_iterations):
logger.info(f"Iteration {i}:")
current_score = self.classification_score
self._write_ivectors()
self.compute_plda()
self.refresh_plda_vectors()
self.refresh_speaker_vectors()
self.export_speaker_ivectors()
self.classify_iteration(i)
improvement = self.classification_score - current_score
with self.session() as session:
uncategorized_count = (
session.query(Utterance)
.join(Utterance.speaker)
.filter(Speaker.name == "MFA_UNKNOWN")
.count()
)
logger.info(f"Average PLDA score {self.classification_score:.4f}")
logger.info(f"Improvement: {improvement:.4f}")
logger.info(f"Number of speakers: {self.num_speakers}")
logger.info(f"Unclassified utterances: {uncategorized_count}")
logger.debug(f"Found {self.num_speakers} clusters")
if config.DEBUG and self.num_utterances < 100000:
self.visualize_current_clusters()
def visualize_current_clusters(self):
with self.session() as session:
query = (
session.query(Speaker.name, Utterance.plda_vector)
.join(Utterance.speaker)
.filter(Utterance.plda_vector is not None)
)
dim = PLDA_DIMENSION
num_utterances = query.count()
if num_utterances == 0:
if self.use_xvector:
column = Utterance.xvector
dim = XVECTOR_DIMENSION
else:
column = Utterance.ivector
dim = IVECTOR_DIMENSION
query = (
session.query(Speaker.name, column)
.join(Utterance.speaker)
.filter(column is not None)
)
num_utterances = query.count()
if num_utterances == 0:
logger.warning("No ivectors/xvectors to visualize")
return
ivectors = np.empty((query.count(), dim))
labels = []
for s_name, ivector in query:
ivectors[len(labels), :] = ivector
labels.append(s_name)
self.visualize_clusters(ivectors, labels)
[docs]
def cluster_utterances(self) -> None:
"""
Cluster utterances with a ivector or speechbrain model
"""
if self.cluster_type is ClusterType.mfa:
self.cluster_utterances_mfa()
self.fix_speaker_ordering()
if not self.evaluation_mode:
self.cleanup_empty_speakers()
self.refresh_speaker_vectors()
if self.evaluation_mode:
self.evaluate_clustering()
return
self.setup()
os.environ["OMP_NUM_THREADS"] = f"{config.NUM_JOBS}"
os.environ["OPENBLAS_NUM_THREADS"] = f"{config.NUM_JOBS}"
os.environ["MKL_NUM_THREADS"] = f"{config.NUM_JOBS}"
if self.metric is DistanceMetric.plda:
self.plda = read_kaldi_object(Plda, self.plda_path)
if self.evaluation_mode and config.DEBUG:
self.calculate_eer()
logger.info("Clustering utterances (this may take a while, please be patient)...")
with self.session() as session:
if self.use_pca:
query = session.query(Utterance.id, Utterance.plda_vector).filter(
Utterance.plda_vector != None # noqa
)
ivectors = np.empty((query.count(), PLDA_DIMENSION))
elif self.use_xvector:
query = session.query(Utterance.id, Utterance.xvector).filter(
Utterance.xvector != None # noqa
)
ivectors = np.empty((query.count(), XVECTOR_DIMENSION))
else:
query = session.query(Utterance.id, Utterance.ivector).filter(
Utterance.ivector != None # noqa
)
ivectors = np.empty((query.count(), IVECTOR_DIMENSION))
utterance_ids = []
for i, (u_id, ivector) in enumerate(query):
if self.stopped.is_set():
break
utterance_ids.append(u_id)
ivectors[i, :] = ivector
num_utterances = ivectors.shape[0]
kwargs = {}
if self.stopped.is_set():
logger.debug("Stopping clustering early.")
return
kwargs["min_cluster_size"] = self.min_cluster_size
kwargs["distance_threshold"] = self.distance_threshold
if self.cluster_type is ClusterType.agglomerative:
kwargs["memory"] = MEMORY
kwargs["linkage"] = self.linkage
kwargs["n_clusters"] = self.expected_num_speakers
if not self.expected_num_speakers:
kwargs["n_clusters"] = None
elif self.cluster_type is ClusterType.spectral:
kwargs["n_clusters"] = self.expected_num_speakers
elif self.cluster_type is ClusterType.hdbscan:
kwargs["memory"] = MEMORY
elif self.cluster_type is ClusterType.optics:
kwargs["memory"] = MEMORY
elif self.cluster_type is ClusterType.kmeans:
kwargs["n_clusters"] = self.expected_num_speakers
labels = cluster_matrix(
ivectors,
self.cluster_type,
metric=self.metric,
plda=self.plda,
working_directory=self.working_directory,
**kwargs,
)
if self.stopped.is_set():
logger.debug("Stopping clustering early.")
return
if config.DEBUG:
self.visualize_clusters(ivectors, labels)
utterance_clusters = collections.defaultdict(list)
for i in range(num_utterances):
u_id = utterance_ids[i]
cluster_id = int(labels[i])
utterance_clusters[cluster_id].append(u_id)
utterance_mapping = []
next_speaker_id = self.get_next_primary_key(Speaker)
speaker_mapping = []
unknown_speaker_id = None
for cluster_id, utterance_ids in sorted(utterance_clusters.items()):
if cluster_id < 0:
if unknown_speaker_id is None:
speaker_name = "MFA_UNKNOWN"
speaker_mapping.append({"id": next_speaker_id, "name": speaker_name})
speaker_id = next_speaker_id
unknown_speaker_id = speaker_id
next_speaker_id += 1
else:
speaker_id = unknown_speaker_id
else:
speaker_name = f"Cluster {cluster_id}"
speaker_mapping.append({"id": next_speaker_id, "name": speaker_name})
speaker_id = next_speaker_id
next_speaker_id += 1
for u_id in utterance_ids:
utterance_mapping.append({"id": u_id, "speaker_id": speaker_id})
if self.stopped.is_set():
logger.debug("Stopping clustering early.")
return
if speaker_mapping:
session.bulk_insert_mappings(Speaker, speaker_mapping)
session.flush()
session.commit()
bulk_update(session, Utterance, utterance_mapping)
session.flush()
session.commit()
if not self.evaluation_mode:
self.clean_up_unknown_speaker()
self.fix_speaker_ordering()
if not self.evaluation_mode:
self.cleanup_empty_speakers()
self.refresh_speaker_vectors()
if self.evaluation_mode:
self.evaluate_clustering()
os.environ["OMP_NUM_THREADS"] = f"{config.BLAS_NUM_THREADS}"
os.environ["OPENBLAS_NUM_THREADS"] = f"{config.BLAS_NUM_THREADS}"
os.environ["MKL_NUM_THREADS"] = f"{config.BLAS_NUM_THREADS}"
def clean_up_unknown_speaker(self):
with self.session() as session:
unknown_speaker = session.query(Speaker).filter(Speaker.name == "MFA_UNKNOWN").first()
next_speaker_id = self.get_next_primary_key(Speaker)
if unknown_speaker is not None:
speaker_mapping = {}
utterance_mapping = []
query = (
session.query(File.id, File.name)
.join(File.utterances)
.filter(Utterance.speaker_id == unknown_speaker.id)
.distinct()
)
for file_id, file_name in query:
speaker_mapping[file_id] = {"id": next_speaker_id, "name": file_name}
next_speaker_id += 1
query = (
session.query(Utterance.id, Utterance.file_id)
.join(File.utterances)
.filter(Utterance.speaker_id == unknown_speaker.id)
)
for utterance_id, file_id in query:
utterance_mapping.append(
{"id": utterance_id, "speaker_id": speaker_mapping[file_id]["id"]}
)
session.bulk_insert_mappings(Speaker, list(speaker_mapping.values()))
session.flush()
session.execute(
sqlalchemy.delete(SpeakerOrdering).where(
SpeakerOrdering.c.speaker_id == unknown_speaker.id
)
)
session.commit()
bulk_update(session, Utterance, utterance_mapping)
session.commit()
[docs]
def calculate_eer(self) -> typing.Tuple[float, float]:
"""
Calculate Equal Error Rate (EER) and threshold for the diarization metric using the ground truth data.
Returns
-------
float
EER
float
Threshold of EER
"""
if not FOUND_SPEECHBRAIN:
logger.info("No speechbrain found, skipping EER calculation.")
return 0.0, 0.0
logger.info("Calculating EER using ground truth speakers...")
limit_per_speaker = 5
limit_within_speaker = 30
begin = time.time()
arguments = [
ComputeEerArguments(
j.id,
getattr(self, "session" if config.USE_THREADING else "db_string", ""),
None,
self.plda,
self.metric,
self.use_xvector,
limit_within_speaker,
limit_per_speaker,
)
for j in self.jobs
]
match_scores = []
mismatch_scores = []
for result in run_kaldi_function(
ComputeEerFunction, arguments, total_count=self.num_speakers
):
matches, mismatches = result
match_scores.extend(matches)
mismatch_scores.extend(mismatches)
random.shuffle(mismatches)
mismatch_scores = mismatch_scores[: len(match_scores)]
match_scores = np.array(match_scores)
mismatch_scores = np.array(mismatch_scores)
device = torch.device("cuda" if self.cuda else "cpu")
eer, thresh = EER(
torch.tensor(mismatch_scores, device=device),
torch.tensor(match_scores, device=device),
)
logger.debug(
f"Matching scores: {np.min(match_scores):.3f}-{np.max(match_scores):.3f} (mean = {match_scores.mean():.3f}, n = {match_scores.shape[0]})"
)
logger.debug(
f"Mismatching scores: {np.min(mismatch_scores):.3f}-{np.max(mismatch_scores):.3f} (mean = {mismatch_scores.mean():.3f}, n = {mismatch_scores.shape[0]})"
)
logger.info(f"EER: {eer*100:.2f}%")
logger.info(f"Threshold: {thresh:.4f}")
logger.debug(f"Calculating EER took {time.time() - begin:.3f} seconds")
return eer, thresh
[docs]
def load_embeddings(self) -> None:
"""Load embeddings from a speechbrain model"""
logger.info("Loading SpeechBrain embeddings...")
with self.session() as session:
begin = time.time()
update_mapping = {}
arguments = [SpeechbrainArguments(1, self.session, None, self.cuda, self.cluster)]
utterance_ids = []
original_use_mp = config.USE_MP
if self.cuda:
config.update_configuration(
{
"USE_MP": False,
}
)
batch_size = 100000
for u_id, emb in run_kaldi_function(
SpeechbrainEmbeddingFunction, arguments, total_count=self.num_utterances
):
utterance_ids.append(u_id)
update_mapping[u_id] = {"id": u_id, "xvector": emb}
if len(update_mapping) >= batch_size:
bulk_update(session, Utterance, list(update_mapping.values()))
session.commit()
update_mapping = {}
if len(update_mapping):
bulk_update(session, Utterance, list(update_mapping.values()))
session.commit()
if len(utterance_ids):
speaker_ids = (
session.query(Speaker.id)
.join(Utterance.speaker)
.filter(Utterance.id.in_(utterance_ids))
.distinct()
)
session.query(Speaker).filter(Speaker.id.in_(speaker_ids)).update(
{Speaker.modified: True}
)
session.commit()
n_lists = int(math.sqrt(self.num_utterances))
n_probe = int(math.sqrt(n_lists))
session.execute(
sqlalchemy.text(
f"CREATE INDEX IF NOT EXISTS utterance_xvector_index ON utterance USING ivfflat (xvector vector_cosine_ops) WITH (lists = {n_lists});"
)
)
session.execute(sqlalchemy.text(f"SET ivfflat.probes = {n_probe};"))
session.commit()
session.query(Corpus).update({Corpus.xvectors_loaded: True})
session.commit()
if self.cuda:
config.update_configuration(
{
"USE_MP": original_use_mp,
}
)
logger.debug(f"Loading embeddings took {time.time() - begin:.3f} seconds")
def refresh_plda_vectors(self):
begin = time.time()
self.create_new_current_workflow(WorkflowType.speaker_diarization)
self.plda = read_kaldi_object(Plda, self.plda_path)
with self.session() as session:
logger.info("Refreshing utterance PLDA vectors...")
session.execute(sqlalchemy.text("DROP INDEX IF EXISTS utterance_plda_vector_index;"))
session.commit()
c = session.query(Corpus).first()
c.plda_calculated = False
update_mapping = []
utterances = session.query(Utterance.id, c.utterance_ivector_column).filter(
c.utterance_ivector_column != None # noqa
)
with tqdm(total=self.num_utterances, disable=config.QUIET) as pbar:
for utt_id, ivector in utterances:
kaldi_ivector = FloatVector()
kaldi_ivector.from_numpy(ivector)
pbar.update(1)
update_mapping.append(
{
"id": utt_id,
"plda_vector": self.plda.transform_ivector(kaldi_ivector, 1).numpy(),
}
)
bulk_update(session, Utterance, update_mapping)
c.plda_calculated = True
n_lists = int(math.sqrt(self.num_utterances))
n_probe = int(math.sqrt(n_lists))
session.execute(
sqlalchemy.text(
f"CREATE INDEX IF NOT EXISTS utterance_plda_vector_index ON utterance USING ivfflat (plda_vector vector_cosine_ops) WITH (lists = {n_lists});"
)
)
session.execute(sqlalchemy.text(f"SET ivfflat.probes = {n_probe};"))
session.commit()
autocommit_engine = self.db_engine.execution_options(isolation_level="AUTOCOMMIT")
with sqlalchemy.orm.Session(autocommit_engine) as session:
session.execute(sqlalchemy.text("VACUUM ANALYZE Utterance"))
logger.debug(f"Refreshing utterance PLDA vectors took {time.time() - begin:.3f} seconds.")
[docs]
def refresh_speaker_vectors(self) -> None:
"""Refresh speaker vectors following clustering or classification"""
begin = time.time()
logger.info("Refreshing speaker vectors...")
with self.session() as session, tqdm(
total=self.num_speakers, disable=config.QUIET
) as pbar:
if self.use_xvector:
ivector_column = Utterance.xvector
else:
ivector_column = Utterance.ivector
update_mapping = {}
speaker_ids = []
speakers = session.query(Speaker.id)
for (s_id,) in speakers:
query = session.query(ivector_column).filter(Utterance.speaker_id == s_id)
s_ivectors = []
for (u_ivector,) in query:
s_ivectors.append(u_ivector)
if not s_ivectors:
continue
mean_ivector = np.mean(np.array(s_ivectors), axis=0)
speaker_mean = DoubleVector()
speaker_mean.from_numpy(mean_ivector)
ivector_normalize_length(speaker_mean)
speaker_ids.append(s_id)
if self.use_xvector:
key = "xvector"
else:
key = "ivector"
update_mapping[s_id] = {"id": s_id, key: speaker_mean.numpy()}
if self.plda is not None:
update_mapping[s_id]["plda_vector"] = self.plda.transform_ivector(
speaker_mean, len(s_ivectors)
).numpy()
pbar.update(1)
bulk_update(session, Speaker, list(update_mapping.values()))
session.commit()
autocommit_engine = self.db_engine.execution_options(isolation_level="AUTOCOMMIT")
with sqlalchemy.orm.Session(autocommit_engine) as session:
session.execute(sqlalchemy.text("VACUUM ANALYZE Speaker"))
if self.use_xvector:
logger.debug(f"Refreshing speaker xvectors took {time.time() - begin:.3f} seconds.")
else:
logger.debug(f"Refreshing speaker ivectors took {time.time() - begin:.3f} seconds.")
[docs]
def compute_speaker_embeddings(self) -> None:
"""Generate per-speaker embeddings as the mean over their utterances"""
if not self.has_xvectors():
self.load_embeddings()
begin = time.time()
logger.info("Computing SpeechBrain speaker embeddings...")
with self.session() as session:
update_mapping = []
speakers = session.query(Speaker.id).filter(
sqlalchemy.or_(Speaker.xvector == None, Speaker.modified == True) # noqa
)
with tqdm(total=speakers.count(), disable=config.QUIET) as pbar:
for (s_id,) in speakers:
u_query = session.query(Utterance.xvector).filter(
Utterance.speaker_id == s_id, Utterance.xvector != None # noqa
)
num_utterances = u_query.count()
if not num_utterances:
continue
embeddings = np.empty((num_utterances, XVECTOR_DIMENSION))
for i, (xvector,) in enumerate(u_query):
embeddings[i, :] = xvector
speaker_xvector = np.mean(embeddings, axis=0)
speaker_mean = DoubleVector()
speaker_mean.from_numpy(speaker_xvector)
ivector_normalize_length(speaker_mean)
update_mapping.append(
{
"id": s_id,
"xvector": speaker_mean.numpy(),
"modified": False,
"num_utterances": num_utterances,
}
)
pbar.update(1)
logger.debug(f"Updating {len(update_mapping)} speakers...")
bulk_update(session, Speaker, update_mapping)
session.commit()
n_lists = int(math.sqrt(self.num_speakers))
session.execute(
sqlalchemy.text(
f"CREATE INDEX IF NOT EXISTS speaker_xvector_index ON speaker USING ivfflat (xvector vector_cosine_ops) WITH (lists = {n_lists});"
)
)
session.commit()
if update_mapping:
autocommit_engine = self.db_engine.execution_options(isolation_level="AUTOCOMMIT")
with sqlalchemy.orm.Session(autocommit_engine) as session:
session.execute(sqlalchemy.text("VACUUM ANALYZE Speaker"))
logger.debug(f"Computing speaker xvectors took {time.time() - begin:.3f} seconds.")
self.cleanup_empty_speakers()
[docs]
def export_files(self, output_directory: str) -> None:
"""
Export files with their new speaker labels
Parameters
----------
output_directory: str
Output directory to save files
"""
if not self.overwrite and os.path.exists(output_directory):
output_directory = self.working_directory.joinpath("speaker_classification")
os.makedirs(output_directory, exist_ok=True)
diagnostic_files = [
"diarization_evaluation_results.csv",
"cluster_plot.png",
"nearest_neighbors.png",
]
for fname in diagnostic_files:
path = self.working_directory.joinpath(fname)
if os.path.exists(path):
shutil.copyfile(
path,
os.path.join(output_directory, fname),
)
with mfa_open(os.path.join(output_directory, "parameters.yaml"), "w") as f:
yaml.dump(
{
"ivector_extractor_path": str(self.ivector_extractor_path),
"expected_num_speakers": self.expected_num_speakers,
"cluster": self.cluster,
"cuda": self.cuda,
"metric": self.metric.name,
"cluster_type": self.cluster_type.name,
"distance_threshold": self.distance_threshold,
"min_cluster_size": self.min_cluster_size,
"linkage": self.linkage,
},
f,
Dumper=yaml.Dumper,
)
with self.session() as session:
logger.info("Writing output files...")
files = session.query(File).options(
selectinload(File.utterances),
selectinload(File.speakers),
joinedload(File.sound_file, innerjoin=True).load_only(SoundFile.duration),
joinedload(File.text_file, innerjoin=True).load_only(TextFile.file_type),
)
with tqdm(total=self.num_files, disable=config.QUIET) as pbar:
for file in files:
utterance_count = len(file.utterances)
if utterance_count == 0:
logger.debug(f"Could not find any utterances for {file.name}")
continue
output_format = file.text_file.file_type
output_path = construct_output_path(
file.name,
file.relative_path,
output_directory,
output_format=output_format,
)
if output_format == "lab":
with mfa_open(output_path, "w") as f:
f.write(file.utterances[0].text)
else:
data = file.construct_transcription_tiers(original_text=True)
export_textgrid(
data,
output_path,
file.duration,
self.export_frame_shift,
output_format,
)
pbar.update(1)