"""
MFA configuration
=================
"""
from __future__ import annotations
import os
import pathlib
import re
import typing
from typing import Any, Dict, List, Union
import click
import dataclassy
import joblib
import yaml
from dataclassy import dataclass
from montreal_forced_aligner.exceptions import RootDirectoryError
from montreal_forced_aligner.helper import mfa_open
# Names exported by a star-import of this module.
__all__ = [
    "generate_config_path",
    "generate_command_history_path",
    "load_command_history",
    "get_temporary_directory",
    "update_command_history",
    "MfaConfiguration",
    "GLOBAL_CONFIG",
]
# Environment variable consulted by get_temporary_directory() for the root directory.
MFA_ROOT_ENVIRONMENT_VARIABLE = "MFA_ROOT_DIR"
# Environment variable consulted by MfaConfiguration to pick the active profile.
MFA_PROFILE_VARIABLE = "MFA_PROFILE"

# Embedding sizes; all three currently share one value.
# NOTE(review): presumably ivector/xvector/PLDA vector lengths -- not used in this module.
IVECTOR_DIMENSION = XVECTOR_DIMENSION = PLDA_DIMENSION = 192
[docs]
def get_temporary_directory() -> pathlib.Path:
    """
    Get the root temporary directory for MFA

    The location is read from the environment variable named by
    ``MFA_ROOT_ENVIRONMENT_VARIABLE`` and falls back to ``~/Documents/MFA``
    (with ``~`` expanded).  The directory, including any missing parents, is
    created if it does not exist yet.

    Returns
    -------
    Path
        Root temporary directory

    Raises
    ------
    :class:`~montreal_forced_aligner.exceptions.RootDirectoryError`
        If the directory cannot be created
    """
    root_directory = pathlib.Path(
        os.environ.get(MFA_ROOT_ENVIRONMENT_VARIABLE, os.path.expanduser("~/Documents/MFA"))
    )
    try:
        root_directory.mkdir(parents=True, exist_ok=True)
    except OSError as error:
        # Chain the OS-level failure so the underlying cause stays in the traceback
        raise RootDirectoryError(root_directory, MFA_ROOT_ENVIRONMENT_VARIABLE) from error
    return root_directory
[docs]
def generate_config_path() -> pathlib.Path:
    """
    Build the location of MFA's global configuration file

    Returns
    -------
    Path
        ``global_config.yaml`` inside the root temporary directory
    """
    root_directory = get_temporary_directory()
    return root_directory / "global_config.yaml"
[docs]
def generate_command_history_path() -> pathlib.Path:
    """
    Build the location of the command history file

    Returns
    -------
    Path
        ``command_history.yaml`` inside the root temporary directory
    """
    root_directory = get_temporary_directory()
    return root_directory / "command_history.yaml"
[docs]
def load_command_history() -> List[Dict[str, Any]]:
    """
    Load command history for MFA

    Reads the YAML history file if it exists.  For every entry, a leading
    ``<something>.py `` script token in its ``"command"`` string is replaced
    with ``mfa ``.

    Returns
    -------
    list[dict[str, Any]]
        List of commands previously run; empty when there is no history file
        or the file has no content
    """
    path = generate_command_history_path()
    history = []
    if path.exists():
        with mfa_open(path, "r") as f:
            # An empty file parses to None; treat that as "no history"
            history = yaml.safe_load(f) or []
    for h in history:
        # The dot is escaped so that only a literal ".py" suffix is rewritten;
        # unescaped, it also matched first tokens such as "happy "
        h["command"] = re.sub(r"^\S+\.py ", "mfa ", h["command"])
    return history
[docs]
def update_command_history(command_data: Dict[str, Any]) -> None:
    """
    Append the most recent command to the stored history

    Nothing is written when the second space-separated token of
    ``command_data["command"]`` is ``history``, or when that token cannot be
    extracted at all.  Only the 50 most recent entries are kept.

    Parameters
    ----------
    command_data: dict[str, Any]
        Current command metadata
    """
    try:
        subcommand = command_data["command"].split(" ")[1]
    except Exception:
        # Best effort: malformed metadata is silently not recorded
        return
    if subcommand == "history":
        return
    previous_commands = load_command_history()
    history_path = generate_command_history_path()
    retained = [*previous_commands, command_data][-50:]
    with mfa_open(history_path, "w") as f:
        yaml.safe_dump(retained, f, allow_unicode=True)
[docs]
@dataclass(slots=True)
class MfaProfile:
    """
    Configuration class for a profile used from the command line

    Fields can be read either as attributes or with item access
    (``profile["num_jobs"]``) and are bulk-updated through :meth:`update`.
    """

    clean: bool = False
    verbose: bool = False
    debug: bool = False
    quiet: bool = False
    overwrite: bool = False
    terminal_colors: bool = True
    cleanup_textgrids: bool = True
    database_backend: str = "psycopg2"
    database_port: int = 5433
    # Handed to joblib.Memory as ``bytes_limit`` at the bottom of this module.
    # Written as an int literal so the default matches the annotation.
    bytes_limit: int = 100_000_000
    seed: int = 0
    num_jobs: int = 3
    blas_num_threads: int = 1
    use_mp: bool = True
    single_speaker: bool = False
    # The default is computed once, when the class body is executed.
    temporary_directory: pathlib.Path = get_temporary_directory()
    github_token: typing.Optional[str] = None

    def __getitem__(self, item):
        """Get key from profile"""
        return getattr(self, item)

    def update(self, data: Union[Dict[str, Any], click.Context]) -> None:
        """
        Update configuration from new data

        ``None`` values are skipped and keys that do not name a field of the
        profile are ignored.  The key ``temp_directory`` is accepted as an
        alias for ``temporary_directory``, whose value is converted to a
        :class:`pathlib.Path`.

        Parameters
        ----------
        data: typing.Union[dict[str, typing.Any], :class:`click.Context`]
            Parameters to update; only its ``items()`` method is used
        """
        for k, v in data.items():
            # Skip unset values before any conversion: pathlib.Path(None) raises TypeError
            if v is None:
                continue
            if k == "temp_directory":
                k = "temporary_directory"
            if k == "temporary_directory":
                v = pathlib.Path(v)
            if hasattr(self, k):
                setattr(self, k, v)
[docs]
class MfaConfiguration:
    """
    Global MFA configuration class

    Keeps a ``"global"`` profile plus any named profiles and persists them as
    YAML at the path given by :func:`generate_config_path`.  Attribute and
    item access that is not satisfied by the configuration object itself is
    forwarded to the current profile, yielding ``None`` for unknown names.
    """

    def __init__(self):
        self.current_profile_name = os.getenv(MFA_PROFILE_VARIABLE, "global")
        self.config_path = generate_config_path()
        self.global_profile = MfaProfile()
        self.profiles: Dict[str, MfaProfile] = {}
        self.profiles["global"] = self.global_profile
        # First run writes the defaults; later runs read the stored settings
        if not os.path.exists(self.config_path):
            self.save()
        else:
            self.load()

    def __getattr__(self, item):
        """Get key from current profile"""
        # __getattr__ only runs when normal lookup fails.  If ``profiles`` has
        # not been set yet (e.g. the instance was created without __init__),
        # resolving the current profile would re-enter this method without end.
        if "profiles" not in self.__dict__:
            raise AttributeError(item)
        return getattr(self.current_profile, item, None)

    def __getitem__(self, item):
        """Get key from current profile"""
        return getattr(self.current_profile, item, None)

    @property
    def current_profile(self) -> MfaProfile:
        """
        The current :class:`~montreal_forced_aligner.config.MfaProfile`

        The profile name is re-read from the environment on every access; a
        profile that does not exist yet is created as a copy of the global
        profile.
        """
        self.current_profile_name = os.getenv(MFA_PROFILE_VARIABLE, "global")
        if self.current_profile_name not in self.profiles:
            self.profiles[self.current_profile_name] = MfaProfile()
            self.profiles[self.current_profile_name].update(
                dataclassy.asdict(self.global_profile)
            )
        return self.profiles[self.current_profile_name]

    def save(self) -> None:
        """Save MFA configuration"""
        # Recomputing the path also goes through get_temporary_directory(),
        # which creates the root directory if it is missing
        global_configuration_file = generate_config_path()
        # Global settings sit at the top level; other profiles are nested under "profiles"
        data = dataclassy.asdict(self.global_profile)
        data["profiles"] = {
            k: dataclassy.asdict(v) for k, v in self.profiles.items() if k != "global"
        }
        with mfa_open(global_configuration_file, "w") as f:
            yaml.dump(data, f)

    def load(self) -> None:
        """
        Load MFA configuration

        An empty configuration file, or one with an empty ``profiles`` entry,
        leaves the defaults in place.
        """
        with mfa_open(self.config_path, "r") as f:
            # NOTE(security): yaml.Loader can construct arbitrary Python objects,
            # so this file must be trusted.  Presumably required because save()
            # uses yaml.dump on values such as pathlib.Path -- confirm before
            # switching to yaml.safe_load.
            data = yaml.load(f, Loader=yaml.Loader) or {}
        for name, p in (data.pop("profiles", None) or {}).items():
            self.profiles[name] = MfaProfile()
            self.profiles[name].update(p)
        self.global_profile.update(data)
        # A profile selected through the environment but absent from the file
        # starts from the global settings
        if (
            self.current_profile_name not in self.profiles
            and self.current_profile_name != "global"
        ):
            self.profiles[self.current_profile_name] = MfaProfile()
            self.profiles[self.current_profile_name].update(data)
# Module-wide configuration, loaded (or created on disk) at import time.
GLOBAL_CONFIG = MfaConfiguration()

# joblib cache stored under the root temporary directory; joblib is verbose
# only when the active profile asks for it.
MEMORY = joblib.Memory(
    location=os.path.join(os.fspath(get_temporary_directory()), "joblib_cache"),
    verbose=0 if not GLOBAL_CONFIG.current_profile.verbose else 4,
    bytes_limit=GLOBAL_CONFIG.current_profile.bytes_limit,
)