"""
Abstract Base Classes
=====================
"""
from __future__ import annotations
import abc
import contextlib
import logging
import os
import shutil
import subprocess
import sys
import time
import traceback
import typing
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union,
get_type_hints,
)
import sqlalchemy
import yaml
from sqlalchemy.orm import scoped_session, sessionmaker
from montreal_forced_aligner import config
from montreal_forced_aligner.db import CorpusWorkflow, MfaSqlBase
from montreal_forced_aligner.exceptions import (
DatabaseError,
KaldiProcessingError,
MultiprocessingError,
)
from montreal_forced_aligner.helper import comma_join, load_configuration, mfa_open
if TYPE_CHECKING:
from pathlib import Path
from montreal_forced_aligner.data import MfaArguments, WorkflowType
# Public API of this module: the abstract bases and mixins that concrete
# MFA workers/trainers/models are composed from
__all__ = [
    "MfaModel",
    "MfaWorker",
    "TopLevelMfaWorker",
    "MetaDict",
    "DatabaseMixin",
    "FileExporterMixin",
    "ModelExporterMixin",
    "TemporaryDirectoryMixin",
    "AdapterMixin",
    "TrainerMixin",
    "KaldiFunction",
]

# Configuration types
# Alias for configuration dictionaries parsed from yaml files / CLI arguments
MetaDict = Dict[str, Any]

# Module-wide logger shared by all MFA components
logger = logging.getLogger("mfa")
class KaldiFunction(metaclass=abc.ABCMeta):
    """
    Abstract class for running Kaldi functions

    Subclasses implement :meth:`_run`; :meth:`run` wraps it with error
    handling so worker processes surface full tracebacks.

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.data.MfaArguments`
        Arguments bundle carrying the job name, log path and either a
        database connection string or an existing session factory
    """

    def __init__(self, args: MfaArguments):
        self.args = args
        self.db_string = None
        self._session = None
        # ``args.session`` is either a connection string (separate process)
        # or a session factory to reuse (threaded execution)
        if isinstance(self.args.session, str):
            self.db_string = self.args.session
        else:
            self._session = self.args.session
        self.job_name = self.args.job_name
        self.log_path = self.args.log_path
        self.callback = None

    @contextlib.contextmanager
    def session(self):
        """Yield a database session, constructing a fresh engine when only a
        connection string was supplied."""
        if self._session is not None:
            with self._session() as session:
                yield session
        else:
            db_engine = sqlalchemy.create_engine(self.db_string)
            with sqlalchemy.orm.Session(db_engine) as session:
                yield session

    def run(self):
        """Run the function, calls subclassed object's ``_run`` with error handling"""
        try:
            # Having a shared session factory implies threaded execution in
            # the main process rather than a separate worker process
            config.USE_THREADING = self._session is not None
            self._run()
        except Exception as e:
            # format_exc() yields the traceback as a single correctly-spaced
            # string (joining format_exception() output with "\n" doubled the
            # newlines); chain the cause so it is not lost
            error_text = traceback.format_exc()
            raise MultiprocessingError(self.job_name, error_text) from e

    def _run(self) -> None:
        """Internal logic for running the worker"""
        pass

    def check_call(self, proc: subprocess.Popen):
        """
        Check whether a subprocess successfully completed

        Parameters
        ----------
        proc: subprocess.Popen
            Subprocess to check

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
            If there was an error running the subprocess
        """
        if proc.returncode is None:
            proc.wait()
        if proc.returncode != 0:
            raise KaldiProcessingError([self.log_path])
class TemporaryDirectoryMixin(metaclass=abc.ABCMeta):
    """
    Abstract mixin class for MFA temporary directories

    Subclasses supply :meth:`identifier`, :meth:`data_source_identifier` and
    :meth:`output_directory`; this mixin derives the standard layout of
    corpus, dictionary and model directories beneath that root.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Explicit overrides for the derived directory properties below;
        # while None, the default location under ``output_directory`` is used
        self._corpus_output_directory = None
        self._dictionary_output_directory = None
        self._language_model_output_directory = None
        self._acoustic_model_output_directory = None
        self._g2p_model_output_directory = None
        self._ivector_extractor_output_directory = None
        self._current_workflow = None

    @property
    @abc.abstractmethod
    def identifier(self) -> str:
        """Identifier to use in creating the temporary directory"""
        ...

    @property
    @abc.abstractmethod
    def data_source_identifier(self) -> str:
        """Identifier for the data source (generally the corpus being used)"""
        ...

    @property
    @abc.abstractmethod
    def output_directory(self) -> Path:
        """Root temporary directory"""
        ...

    def clean_working_directory(self) -> None:
        """Clean up previous runs by removing the root temporary directory"""
        shutil.rmtree(self.output_directory, ignore_errors=True)

    @property
    def corpus_output_directory(self) -> Path:
        """Temporary directory containing all corpus information"""
        if self._corpus_output_directory:
            return self._corpus_output_directory
        return self.output_directory.joinpath(self.data_source_identifier)

    @corpus_output_directory.setter
    def corpus_output_directory(self, directory: Path) -> None:
        self._corpus_output_directory = directory

    @property
    def dictionary_output_directory(self) -> Path:
        """Temporary directory containing all dictionary information"""
        if self._dictionary_output_directory:
            return self._dictionary_output_directory
        return self.output_directory.joinpath("dictionary")

    @property
    def model_output_directory(self) -> Path:
        """Temporary directory containing all exported models"""
        return self.output_directory.joinpath("models")

    @dictionary_output_directory.setter
    def dictionary_output_directory(self, directory: Path) -> None:
        self._dictionary_output_directory = directory

    @property
    def language_model_output_directory(self) -> Path:
        """Temporary directory containing language model files"""
        if self._language_model_output_directory:
            return self._language_model_output_directory
        return self.model_output_directory.joinpath("language_model")

    @language_model_output_directory.setter
    def language_model_output_directory(self, directory: Path) -> None:
        self._language_model_output_directory = directory

    @property
    def acoustic_model_output_directory(self) -> Path:
        """Temporary directory containing acoustic model files"""
        if self._acoustic_model_output_directory:
            return self._acoustic_model_output_directory
        return self.model_output_directory.joinpath("acoustic_model")

    @acoustic_model_output_directory.setter
    def acoustic_model_output_directory(self, directory: Path) -> None:
        self._acoustic_model_output_directory = directory
class DatabaseMixin(TemporaryDirectoryMixin, metaclass=abc.ABCMeta):
    """
    Abstract class for mixing in database functionality

    Manages a lazily-constructed SQLAlchemy engine and scoped session over
    either a PostgreSQL database (when ``config.USE_POSTGRES`` is set) or an
    on-disk SQLite file under :meth:`output_directory`.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Engine and session factory are created on first access
        # (see ``db_engine`` and ``session`` properties)
        self._db_engine = None
        self._db_path = None
        self._session = None
        self.database_initialized = False

    def cleanup_connections(self) -> None:
        """Dispose of the scoped session and engine, releasing all connections."""
        if getattr(self, "_session", None) is not None:
            self._session.remove()
            del self._session
            self._session = None

        if getattr(self, "_db_engine", None) is not None:
            self._db_engine.dispose()
            del self._db_engine
            self._db_engine = None

    def delete_database(self) -> None:
        """
        Reset all schemas

        Drops all tables for PostgreSQL; removes the database file for SQLite.
        """
        if config.USE_POSTGRES:
            MfaSqlBase.metadata.drop_all(self.db_engine)
        elif self.db_path.exists():
            os.remove(self.db_path)

    def initialize_database(self) -> None:
        """
        Initialize the database with database schema

        For PostgreSQL, creates the database via ``createdb`` if it cannot be
        reached, then installs required extensions before creating tables.

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.DatabaseError`
            If the PostgreSQL server cannot be reached and the database
            cannot be created
        """
        if self.database_initialized:
            return
        from montreal_forced_aligner.command_line.utils import check_databases

        if config.USE_POSTGRES:
            exist_check = True
            try:
                check_databases(self.identifier)
            except Exception:
                # Database missing or unreachable; attempt to create it with
                # the ``createdb`` command line utility
                try:
                    subprocess.check_call(
                        [
                            "createdb",
                            f"--host={config.database_socket()}",
                            self.identifier,
                        ],
                        stderr=subprocess.DEVNULL,
                        stdout=subprocess.DEVNULL,
                    )
                except Exception:
                    raise DatabaseError(
                        f"There was an error connecting to the {config.CURRENT_PROFILE_NAME} MFA database server. "
                        "Please ensure the server is initialized (mfa server init) or running (mfa server start)"
                    )
                # Freshly created, so the schema does not exist yet
                exist_check = False
        else:
            exist_check = self.db_path.exists()
        self.database_initialized = True
        if config.CLEAN or getattr(self, "dirty", False):
            self.clean_working_directory()
        if exist_check:
            if config.CLEAN or getattr(self, "dirty", False):
                self.delete_database()
            else:
                # Existing database and no clean requested: reuse it as-is
                return
        os.makedirs(self.output_directory, exist_ok=True)
        if config.USE_POSTGRES:
            with self.db_engine.connect() as conn:
                conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector"))
                conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
                conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS pg_stat_statements"))
                # ``setseed`` expects a value in [-1, 1]; presumably
                # config.SEED is bounded so SEED/32768 lands in range --
                # TODO confirm
                conn.execute(sqlalchemy.text(f"select setseed({config.SEED/32768})"))
                conn.commit()

        MfaSqlBase.metadata.create_all(self.db_engine)

    @property
    def db_engine(self) -> sqlalchemy.engine.Engine:
        """Database engine"""
        if self._db_engine is None:
            self._db_engine = self.construct_engine()
        return self._db_engine

    def get_next_primary_key(self, database_table: MfaSqlBase):
        """Return the next unused primary key (max id + 1) for ``database_table``."""
        with self.session() as session:
            pk = session.query(sqlalchemy.func.max(database_table.id)).scalar()
            if not pk:
                pk = 0
            return pk + 1

    def create_new_current_workflow(self, workflow_type: WorkflowType, name: str = None):
        """
        Create (or reuse) a workflow and flag it as the current one.

        Parameters
        ----------
        workflow_type: :class:`~montreal_forced_aligner.data.WorkflowType`
            Type of workflow; its name is used when ``name`` is not given
        name: str, optional
            Explicit workflow name
        """
        from montreal_forced_aligner.db import CorpusWorkflow

        with self.session() as session:
            if not name:
                name = workflow_type.name
            self._current_workflow = name

            # Clear the current flag on every workflow before setting it on
            # the requested one
            session.query(CorpusWorkflow).update({"current": False})
            new_workflow = (
                session.query(CorpusWorkflow).filter(CorpusWorkflow.name == name).first()
            )
            if not new_workflow:
                new_workflow = CorpusWorkflow(
                    name=name,
                    workflow_type=workflow_type,
                    working_directory=os.path.join(self.output_directory, name),
                    current=True,
                )
                log_dir = os.path.join(new_workflow.working_directory, "log")
                os.makedirs(log_dir, exist_ok=True)
                session.add(new_workflow)
            else:
                new_workflow.current = True
            session.commit()

    def set_current_workflow(self, identifier):
        """Mark the existing workflow named ``identifier`` as current."""
        from montreal_forced_aligner.db import CorpusWorkflow

        with self.session() as session:
            session.query(CorpusWorkflow).update({CorpusWorkflow.current: False})
            # NOTE(review): assumes a workflow with this name exists;
            # ``wf`` would be None otherwise and raise AttributeError
            wf = session.query(CorpusWorkflow).filter(CorpusWorkflow.name == identifier).first()
            wf.current = True
            self._current_workflow = identifier
            session.commit()

    @property
    def current_workflow(self) -> CorpusWorkflow:
        """The workflow row currently flagged as current in the database."""
        from montreal_forced_aligner.db import CorpusWorkflow

        with self.session() as session:
            wf = (
                session.query(CorpusWorkflow)
                .filter(CorpusWorkflow.current == True)  # noqa
                .first()
            )
        return wf

    @property
    def db_path(self) -> Path:
        """Connection path for sqlite database"""
        return self.output_directory.joinpath(f"{self.identifier}.db")

    @property
    def db_string(self) -> str:
        """Connection string for the database"""
        if config.USE_POSTGRES:
            return f"postgresql+psycopg2://@/{self.identifier}?host={config.database_socket()}"
        else:
            return f"sqlite:///{self.db_path}"

    def construct_engine(self, **kwargs) -> sqlalchemy.engine.Engine:
        """
        Construct a database engine

        Parameters
        ----------
        **kwargs
            Passed through to :func:`sqlalchemy.create_engine`; for SQLite,
            ``read_only=True`` is consumed here and opens the database file
            in read-only mode

        Returns
        -------
        :class:`~sqlalchemy.engine.Engine`
            SqlAlchemy engine
        """
        db_string = self.db_string
        if not config.USE_POSTGRES:
            if kwargs.pop("read_only", False):
                db_string += "?mode=ro&nolock=1&uri=true"
        # Size the connection pool to comfortably cover one connection per job
        kwargs["pool_size"] = config.NUM_JOBS + 10
        kwargs["max_overflow"] = config.NUM_JOBS + 10
        e = sqlalchemy.create_engine(
            db_string,
            **kwargs,
        )

        return e

    @property
    def session(self) -> sqlalchemy.orm.scoped_session:
        """
        Database session factory, constructed lazily on first access and
        cached for the lifetime of the worker

        Returns
        -------
        :class:`~sqlalchemy.orm.scoped_session`
            SqlAlchemy scoped session factory bound to :meth:`db_engine`
        """
        if self._session is None:
            self._session = scoped_session(
                sessionmaker(bind=self.db_engine, expire_on_commit=False)
            )
        return self._session
class MfaWorker(metaclass=abc.ABCMeta):
    """
    Abstract class for MFA workers

    Attributes
    ----------
    dirty: bool
        Flag for whether an error was encountered in processing
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dirty = False

    @classmethod
    def get_configuration_parameters(cls) -> Dict[str, Type]:
        """
        Get the types of parameters available to be configured

        Collects ``__init__`` annotations across the class's full MRO;
        ``Optional``/``Union`` annotations are collapsed to their first
        argument and parameterized container annotations to the bare
        container type.

        Returns
        -------
        dict[str, Type]
            Dictionary of parameter names and their types
        """
        mapping = {Dict: dict, Tuple: tuple, List: list, Set: set}
        configuration_params = {}
        # Walking the MRO outward means base-class annotations overwrite
        # subclass annotations for shared parameter names; the previously
        # duplicated pre-pass over cls.__init__ alone was redundant since
        # cls is the first entry of cls.mro()
        for c in cls.mro():
            try:
                for name, annotation in get_type_hints(c.__init__).items():
                    configuration_params[name] = annotation
                    try:
                        # Collapse Optional[X] / Union[X, ...] to X
                        if annotation.__origin__ == Union:
                            configuration_params[name] = annotation.__args__[0]
                    except AttributeError:
                        pass
            except AttributeError:
                pass
        # Collapse parameterized containers (e.g. List[str] -> list)
        for name, annotation in configuration_params.items():
            for container in mapping.values():
                try:
                    if annotation.__origin__ == container:
                        configuration_params[name] = container
                        break
                except AttributeError:
                    # Plain (unparameterized) annotation; nothing to collapse
                    break
        return configuration_params

    @property
    def configuration(self) -> MetaDict:
        """Configuration parameters"""
        return {
            "dirty": self.dirty,
        }

    @property
    @abc.abstractmethod
    def working_directory(self) -> Path:
        """Current working directory"""
        ...

    @property
    def working_log_directory(self) -> Path:
        """Current working log directory"""
        return self.working_directory.joinpath("log")

    @property
    @abc.abstractmethod
    def data_directory(self) -> Path:
        """Data directory"""
        ...
class TopLevelMfaWorker(MfaWorker, TemporaryDirectoryMixin, metaclass=abc.ABCMeta):
    """
    Abstract mixin for top-level workers in MFA. This class holds properties about the larger workflow run.

    Parameters
    ----------
    num_jobs: int
        Number of jobs and processes to use
    clean: bool
        Flag for whether to remove any old files in the work directory
    """

    # Configuration keys whose ``None`` value in a loaded config file is
    # normalized to an empty list (see ``parse_parameters``)
    nullable_fields = [
        "punctuation",
        "compound_markers",
        "clitic_markers",
        "quote_markers",
        "word_break_markers",
    ]

    def __init__(
        self,
        **kwargs,
    ):
        # NOTE(review): ``extract_relevant_parameters`` is defined elsewhere
        # (not visible here); presumably it splits kwargs into accepted keys
        # and skipped keys -- confirm against the class that provides it
        kwargs, skipped = type(self).extract_relevant_parameters(kwargs)
        super().__init__(**kwargs)
        self.initialized = False
        self.start_time = time.time()
        self.setup_logger()
        if skipped:
            logger.warning(f"Skipped the following configuration keys: {comma_join(skipped)}")

    def cleanup_logger(self):
        """Ensure that loggers are cleaned up on delete"""
        logger = logging.getLogger("mfa")
        handlers = logger.handlers[:]
        # Close and detach only file handlers; stream handlers remain attached
        for handler in handlers:
            if isinstance(handler, logging.FileHandler):
                handler.close()
                logger.removeHandler(handler)

    def setup(self) -> None:
        """Setup for worker"""
        self.check_previous_run()
        # Database/corpus hooks are provided by sibling mixins when present
        if hasattr(self, "initialize_database"):
            self.initialize_database()
        if hasattr(self, "inspect_database"):
            self.inspect_database()

    @property
    def working_directory(self) -> Path:
        """Alias for a folder that contains worker information, separate from the data directory"""
        return self.output_directory.joinpath(self._current_workflow)

    @classmethod
    def parse_args(
        cls, args: Optional[Dict[str, Any]], unknown_args: Optional[List[str]]
    ) -> MetaDict:
        """
        Class method for parsing configuration parameters from command line arguments

        Parameters
        ----------
        args: dict[str, Any]
            Parsed arguments
        unknown_args: list[str]
            Optional list of arguments that were not parsed

        Returns
        -------
        dict[str, Any]
            Dictionary of specified configuration parameters
        """
        param_types = cls.get_configuration_parameters()
        params = {}
        unknown_dict = {}
        if unknown_args:
            for i, a in enumerate(unknown_args):
                if not a.startswith("--"):
                    continue
                # NOTE(review): ``replace`` strips every occurrence of "--",
                # not just the leading prefix -- harmless for conventional
                # option names but worth confirming
                name = a.replace("--", "")
                if name not in param_types:
                    continue
                # An option with no following value (or followed directly by
                # another option) is treated as a boolean flag
                if i == len(unknown_args) - 1 or unknown_args[i + 1].startswith("--"):
                    val = True
                else:
                    val = unknown_args[i + 1]
                unknown_dict[name] = val
        for name, param_type in param_types.items():
            # Skip path/directory parameters, except for the whitelisted ones
            if (name.endswith("_directory") and name != "audio_directory") or (
                name.endswith("_path") and name not in {"rules_path", "phone_groups_path"}
            ):
                continue
            if args is not None and name in args and args[name] is not None:
                params[name] = param_type(args[name])
            elif name in unknown_dict:
                params[name] = param_type(unknown_dict[name])
                # bool("false") would be truthy, so parse string booleans
                # explicitly
                if param_type == bool and not isinstance(unknown_dict[name], bool):
                    if unknown_dict[name].lower() == "false":
                        params[name] = False
        return params

    @classmethod
    def parse_parameters(
        cls,
        config_path: Optional[Path] = None,
        args: Optional[Dict[str, Any]] = None,
        unknown_args: Optional[typing.Iterable[str]] = None,
    ) -> MetaDict:
        """
        Parse configuration parameters from a config file and command line arguments

        Parameters
        ----------
        config_path: :class:`~pathlib.Path`, optional
            Path to yaml configuration file
        args: dict[str, Any]
            Parsed arguments
        unknown_args: list[str]
            Optional list of arguments that were not parsed

        Returns
        -------
        dict[str, Any]
            Dictionary of specified configuration parameters
        """
        global_params = {}
        if config_path and os.path.exists(config_path):
            data = load_configuration(config_path)
            for k, v in data.items():
                if v is None and k in cls.nullable_fields:
                    v = []
                global_params[k] = v
        # Command line arguments take precedence over the config file
        global_params.update(cls.parse_args(args, unknown_args))
        return global_params

    @property
    def worker_config_path(self) -> str:
        """Path to worker's configuration in the working directory"""
        return os.path.join(self.output_directory, f"{self.data_source_identifier}.yaml")

    def cleanup(self) -> None:
        """
        Clean up loggers and output final message for top-level workers
        """
        try:
            if hasattr(self, "cleanup_connections"):
                self.cleanup_connections()
            if self.dirty:
                logger.error("There was an error in the run, please see the log.")
            else:
                logger.info(f"Done! Everything took {time.time() - self.start_time:.3f} seconds")
            self.save_worker_config()
            self.cleanup_logger()
        except (NameError, ValueError):  # already cleaned up
            pass

    def save_worker_config(self) -> None:
        """Export worker configuration to its working directory"""
        if not os.path.exists(self.output_directory):
            return
        with mfa_open(self.worker_config_path, "w") as f:
            yaml.dump(self.configuration, f)

    def _validate_previous_configuration(self, conf: MetaDict) -> None:
        """
        Validate the current configuration against a previous configuration

        Sets ``self.dirty`` when the previous run used a different MFA
        version (unless running in debug mode).

        Parameters
        ----------
        conf: dict[str, Any]
            Previous run's configuration
        """
        from montreal_forced_aligner.utils import get_mfa_version

        self.dirty = False
        current_version = get_mfa_version()
        if not config.DEBUG and conf.get("version", current_version) != current_version:
            logger.debug(
                f"Previous run was on {conf['version']} version (new run: {current_version})"
            )
            self.dirty = True

    def check_previous_run(self) -> None:
        """
        Check whether a previous run has any conflicting settings with the current run.

        Returns
        -------
        bool
            Flag for whether the current run is compatible with the previous one

        .. note::
            (review) the return value is inconsistent: ``True`` when no
            previous configuration exists, ``False`` on a YAML parse error,
            and implicitly ``None`` after a successful validation; callers
            visible in this file ignore the result.
        """
        if not os.path.exists(self.worker_config_path):
            return True
        try:
            conf = load_configuration(self.worker_config_path)
            self._validate_previous_configuration(conf)
            if not config.CLEAN and self.dirty:
                logger.warning(
                    "The previous run had a different configuration than the current, which may cause issues."
                    " Please see the log for details or use --clean flag if issues are encountered."
                )
        except yaml.error.YAMLError:
            logger.warning("The previous run's configuration could not be loaded.")
            return False

    @property
    def identifier(self) -> str:
        """Combined identifier of the data source and workflow"""
        return self.data_source_identifier

    @property
    def output_directory(self) -> Path:
        """Root temporary directory to store all of this worker's files"""
        return config.TEMPORARY_DIRECTORY.joinpath(self.identifier)

    @property
    def log_file(self) -> Path:
        """Path to the worker's log file"""
        return self.output_directory.joinpath(f"{self.data_source_identifier}.log")

    def setup_logger(self) -> None:
        """
        Construct a logger for a command line run
        """
        from montreal_forced_aligner.helper import configure_logger
        from montreal_forced_aligner.utils import get_mfa_version

        current_version = get_mfa_version()
        # Remove previous directory if versions are different
        # NOTE(review): despite the comment above, ``clean`` here only
        # affects the debug message emitted at the end of this method
        clean = False
        if os.path.exists(self.worker_config_path):
            conf = load_configuration(self.worker_config_path)
            if conf.get("version", current_version) != current_version:
                clean = True
        os.makedirs(self.output_directory, exist_ok=True)
        configure_logger("mfa", log_file=self.log_file)
        logger = logging.getLogger("mfa")
        logger.debug(f"Beginning run for {self.data_source_identifier}")
        logger.debug(f'Using "{config.CURRENT_PROFILE_NAME}" profile')
        if config.USE_MP:
            logger.debug(f"Using multiprocessing with {config.NUM_JOBS}")
        else:
            logger.debug(f"NOT using multiprocessing with {config.NUM_JOBS}")
        logger.debug(f"Set up logger for MFA version: {current_version}")
        if clean or config.CLEAN:
            logger.debug("Cleaned previous run")
class ExporterMixin(metaclass=abc.ABCMeta):
    """
    Abstract mixin class for exporting any kind of file

    Parameters
    ----------
    overwrite: bool
        Flag for whether to overwrite the specified path if a file exists
    """

    def __init__(self, overwrite: bool = False, **kwargs):
        # Record the flag before delegating, so cooperative initializers
        # further along the MRO can rely on it being set
        self.overwrite = overwrite
        super().__init__(**kwargs)
class ModelExporterMixin(ExporterMixin, metaclass=abc.ABCMeta):
    """
    Abstract mixin class for exporting MFA models

    Concrete subclasses provide :meth:`meta` (the model's training
    configuration) and :meth:`export_model` (serialization to disk).
    """

    @property
    @abc.abstractmethod
    def meta(self) -> MetaDict:
        """Training configuration parameters"""
        ...

    @abc.abstractmethod
    def export_model(self, output_model_path: Path) -> None:
        """
        Abstract method to export an MFA model

        Parameters
        ----------
        output_model_path: :class:`~pathlib.Path`
            Path to export model
        """
        ...
class FileExporterMixin(ExporterMixin, metaclass=abc.ABCMeta):
    """
    Abstract mixin class for exporting TextGrid and text files

    Parameters
    ----------
    cleanup_textgrids: bool
        Flag for whether to clean up exported TextGrids
    """

    @abc.abstractmethod
    def export_files(self, output_directory: str) -> None:
        """
        Export files to an output directory

        Parameters
        ----------
        output_directory: str
            Directory to export to
        """
        ...
class TrainerMixin(ModelExporterMixin):
    """
    Abstract mixin class for MFA trainers

    Parameters
    ----------
    num_iterations: int
        Number of training iterations
    model_version: str
        Override for model version

    Attributes
    ----------
    iteration: int
        Current iteration
    """

    def __init__(self, num_iterations: int = 40, model_version: str = None, **kwargs):
        super().__init__(**kwargs)
        # Training starts from iteration zero; subclasses advance this
        # counter as iterations complete
        self.iteration: int = 0
        self.num_iterations = num_iterations
        self.model_version = model_version

    @abc.abstractmethod
    def initialize_training(self) -> None:
        """Initialize training"""
        ...

    @abc.abstractmethod
    def train(self) -> None:
        """Perform training"""
        ...

    @abc.abstractmethod
    def train_iteration(self) -> None:
        """Run one training iteration"""
        ...

    @abc.abstractmethod
    def finalize_training(self) -> None:
        """Finalize training"""
        ...
class AdapterMixin(ModelExporterMixin):
    """
    Abstract class for MFA model adaptation

    Subclasses implement :meth:`adapt` to fit an existing model to new data.
    """

    @abc.abstractmethod
    def adapt(self) -> None:
        """Perform adaptation"""
        ...
class MfaModel(abc.ABC):
    """Abstract class for MFA models"""

    # Valid archive extensions for this model type; set by subclasses
    extensions: List[str]
    model_type = "base_model"

    @classmethod
    def pretrained_directory(cls) -> Path:
        """Directory that pretrained models are saved in (created on demand)"""
        from .config import get_temporary_directory

        path = get_temporary_directory().joinpath("pretrained_models", cls.model_type)
        path.mkdir(parents=True, exist_ok=True)
        return path

    @classmethod
    def get_available_models(cls) -> List[str]:
        """
        Get a list of available models for a given model type

        Returns
        -------
        list[str]
            List of model names
        """
        # Call pretrained_directory() once: it has a mkdir side effect, so the
        # previous double call (exists check + iterdir) did redundant work
        pretrained = cls.pretrained_directory()
        if not pretrained.exists():
            return []
        # A model's name is the stem of any file with a valid extension
        return [f.stem for f in pretrained.iterdir() if cls.valid_extension(f)]

    @classmethod
    def get_pretrained_path(cls, name: str, enforce_existence: bool = True) -> Path:
        """
        Generate a path to a pretrained model based on its name and model type

        Parameters
        ----------
        name: str
            Name of model
        enforce_existence: bool
            Flag for whether the path must exist (``None`` may be returned
            for missing models otherwise), defaults to True

        Returns
        -------
        Path
            Path to model
        """
        return cls.generate_path(cls.pretrained_directory(), name, enforce_existence)

    @classmethod
    @abc.abstractmethod
    def valid_extension(cls, filename: Path) -> bool:
        """Check whether a file has a valid extension for this model type"""
        ...

    @classmethod
    @abc.abstractmethod
    def generate_path(
        cls, root: Path, name: str, enforce_existence: bool = True
    ) -> Optional[Path]:
        """Generate a path from a root directory"""
        ...

    @abc.abstractmethod
    def pretty_print(self) -> None:
        """Print the model's meta data"""
        ...

    @property
    @abc.abstractmethod
    def meta(self) -> MetaDict:
        """Metadata for the model"""
        ...