"""
Abstract Base Classes
=====================
"""
from __future__ import annotations
import abc
import contextlib
import logging
import os
import re
import shutil
import subprocess
import sys
import time
import traceback
import typing
from pathlib import Path
import requests
import sqlalchemy
import yaml
from sqlalchemy.orm import scoped_session, sessionmaker
from montreal_forced_aligner import config
from montreal_forced_aligner.data import MfaArguments, WorkflowType
from montreal_forced_aligner.db import CorpusWorkflow, MfaSqlBase
from montreal_forced_aligner.exceptions import (
DatabaseError,
KaldiProcessingError,
MultiprocessingError,
)
from montreal_forced_aligner.helper import MfaYamlDumper, comma_join, load_configuration, mfa_open
# Public API of this module
__all__ = [
    # Models
    "MfaModel",
    # Workers
    "MfaWorker",
    "TopLevelMfaWorker",
    # Type aliases
    "MetaDict",
    # Mixins
    "DatabaseMixin",
    "FileExporterMixin",
    "ModelExporterMixin",
    "TemporaryDirectoryMixin",
    "AdapterMixin",
    "TrainerMixin",
    "PhoneRemapperMixin",
    # Kaldi helpers
    "KaldiFunction",
]

# Configuration types: mapping of parameter names to arbitrary values
MetaDict = typing.Dict[str, typing.Any]

# Module-wide logger shared by every MFA worker (the "mfa" logger)
logger = logging.getLogger("mfa")
class KaldiFunction(metaclass=abc.ABCMeta):
    """
    Abstract class for running Kaldi functions

    Parameters
    ----------
    args: :class:`~montreal_forced_aligner.data.MfaArguments`
        Arguments for the function; ``args.session`` is either a database connection
        string or a callable session factory
    """

    def __init__(self, args: MfaArguments):
        self.args = args
        self.db_string = None
        self._session = None
        # A string is a connection string from which ``session`` builds a private engine;
        # anything else is called as a session factory
        if isinstance(self.args.session, str):
            self.db_string = self.args.session
        else:
            self._session = self.args.session
        self.job_name = self.args.job_name
        self.log_path = self.args.log_path
        self.callback = None

    @contextlib.contextmanager
    def session(self):
        """
        Context manager yielding a database session

        Uses the session factory from ``args`` when one was given. Otherwise a temporary
        engine is created from the connection string and disposed when the context exits,
        so its pooled connections are released instead of lingering until garbage
        collection.

        Yields
        ------
        :class:`~sqlalchemy.orm.Session`
            Database session
        """
        if self._session is not None:
            with self._session() as session:
                yield session
            return
        db_engine = sqlalchemy.create_engine(self.db_string)
        try:
            with sqlalchemy.orm.Session(db_engine) as session:
                yield session
        finally:
            db_engine.dispose()

    def run(self):
        """
        Run the function, calls subclassed object's ``_run`` with error handling

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.MultiprocessingError`
            If ``_run`` raises any exception; the formatted traceback is included
        """
        try:
            # Threading mode is enabled exactly when a session factory was supplied
            config.USE_THREADING = self._session is not None
            self._run()
        except Exception:
            # format_exc returns the already newline-terminated traceback lines joined
            # together; joining format_exception output with "\n" double-spaced it
            error_text = traceback.format_exc()
            raise MultiprocessingError(self.job_name, error_text)

    def _run(self) -> None:
        """Internal logic for running the worker, overridden by subclasses"""
        pass

    def check_call(self, proc: subprocess.Popen):
        """
        Check whether a subprocess successfully completed

        Waits for the process if it has not finished yet.

        Parameters
        ----------
        proc: subprocess.Popen
            Subprocess to check

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError`
            If there was an error running the subprocess (non-zero return code)
        """
        if proc.returncode is None:
            proc.wait()
        if proc.returncode != 0:
            raise KaldiProcessingError([self.log_path])
class TemporaryDirectoryMixin(metaclass=abc.ABCMeta):
    """
    Abstract mixin class for MFA temporary directories

    Subclasses provide :attr:`identifier`, :attr:`data_source_identifier` and
    :attr:`output_directory`. The other directory properties are derived from
    :attr:`output_directory` unless a value was assigned through their setters.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Explicit overrides for derived directories; a falsy value means "use the default"
        self._corpus_output_directory = None
        self._dictionary_output_directory = None
        self._language_model_output_directory = None
        self._acoustic_model_output_directory = None
        self._g2p_model_output_directory = None
        self._ivector_extractor_output_directory = None
        self._current_workflow = None

    @property
    @abc.abstractmethod
    def identifier(self) -> str:
        """Identifier to use in creating the temporary directory"""
        ...

    @property
    @abc.abstractmethod
    def data_source_identifier(self) -> str:
        """Identifier for the data source (generally the corpus being used)"""
        ...

    @property
    @abc.abstractmethod
    def output_directory(self) -> Path:
        """Root temporary directory"""
        ...

    def clean_working_directory(self) -> None:
        """Clean up previous runs by removing :attr:`output_directory`, ignoring errors"""
        shutil.rmtree(self.output_directory, ignore_errors=True)

    @property
    def corpus_output_directory(self) -> Path:
        """Temporary directory containing all corpus information"""
        if self._corpus_output_directory:
            return self._corpus_output_directory
        return self.output_directory.joinpath(f"{self.data_source_identifier}")

    @corpus_output_directory.setter
    def corpus_output_directory(self, directory: Path) -> None:
        self._corpus_output_directory = directory

    @property
    def dictionary_output_directory(self) -> Path:
        """Temporary directory containing all dictionary information"""
        if self._dictionary_output_directory:
            return self._dictionary_output_directory
        return self.output_directory.joinpath("dictionary")

    @dictionary_output_directory.setter
    def dictionary_output_directory(self, directory: Path) -> None:
        self._dictionary_output_directory = directory

    @property
    def model_output_directory(self) -> Path:
        """Temporary directory containing all model directories"""
        return self.output_directory.joinpath("models")

    @property
    def language_model_output_directory(self) -> Path:
        """Temporary directory containing all language model information"""
        if self._language_model_output_directory:
            return self._language_model_output_directory
        return self.model_output_directory.joinpath("language_model")

    @language_model_output_directory.setter
    def language_model_output_directory(self, directory: Path) -> None:
        self._language_model_output_directory = directory

    @property
    def acoustic_model_output_directory(self) -> Path:
        """Temporary directory containing all acoustic model information"""
        if self._acoustic_model_output_directory:
            return self._acoustic_model_output_directory
        return self.model_output_directory.joinpath("acoustic_model")

    @acoustic_model_output_directory.setter
    def acoustic_model_output_directory(self, directory: Path) -> None:
        self._acoustic_model_output_directory = directory
class DatabaseMixin(TemporaryDirectoryMixin, metaclass=abc.ABCMeta):
    """
    Abstract class for mixing in database functionality

    The backend is PostgreSQL when ``config.USE_POSTGRES`` is set, otherwise a SQLite
    file at :attr:`db_path` inside :attr:`output_directory`.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._db_engine = None
        self._db_path = None
        self._session = None
        self.database_initialized = False

    def cleanup_connections(self) -> None:
        """Remove the scoped session and dispose of the engine, if they were created"""
        if getattr(self, "_session", None) is not None:
            self._session.remove()
            self._session = None
        if getattr(self, "_db_engine", None) is not None:
            self._db_engine.dispose()
            self._db_engine = None

    def delete_database(self) -> None:
        """
        Reset all schemas

        For PostgreSQL all tables are dropped; for SQLite the database file is removed.
        """
        if config.USE_POSTGRES:
            MfaSqlBase.metadata.drop_all(self.db_engine)
        elif self.db_path.exists():
            os.remove(self.db_path)

    def initialize_database(self) -> None:
        """
        Initialize the database with database schema

        For PostgreSQL, the database is created with ``createdb`` when it cannot be found.
        If ``config.CLEAN`` is set or the worker is dirty, the working directory and any
        existing database are removed first; otherwise an existing database is reused
        without changes.

        Raises
        ------
        :class:`~montreal_forced_aligner.exceptions.DatabaseError`
            If the PostgreSQL database could not be created
        """
        if self.database_initialized:
            return
        from montreal_forced_aligner.command_line.utils import check_databases

        if config.USE_POSTGRES:
            exist_check = True
            try:
                check_databases(self.identifier)
            except Exception:
                try:
                    subprocess.check_call(
                        [
                            "createdb",
                            f"--host={config.database_socket()}",
                            self.identifier,
                        ],
                        stderr=subprocess.DEVNULL,
                        stdout=subprocess.DEVNULL,
                    )
                except Exception as err:
                    raise DatabaseError(
                        f"There was an error connecting to the {config.CURRENT_PROFILE_NAME} MFA database server "
                        f"at {config.database_socket()}. "
                        "Please ensure the server is initialized (mfa server init) or running (mfa server start)"
                    ) from err
                # The database was just created, so there is nothing to reuse or delete
                exist_check = False
        else:
            exist_check = self.db_path.exists()
        self.database_initialized = True
        needs_clean = config.CLEAN or getattr(self, "dirty", False)
        if needs_clean:
            self.clean_working_directory()
        if exist_check:
            if needs_clean:
                self.delete_database()
            else:
                return
        os.makedirs(self.output_directory, exist_ok=True)
        if config.USE_POSTGRES:
            with self.db_engine.connect() as conn:
                conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector"))
                conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
                conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS pg_stat_statements"))
                # PostgreSQL setseed expects a value in [-1, 1]
                conn.execute(sqlalchemy.text(f"select setseed({config.SEED / 32768})"))
                conn.commit()
        MfaSqlBase.metadata.create_all(self.db_engine)

    @property
    def db_engine(self) -> sqlalchemy.engine.Engine:
        """Database engine, constructed on first access"""
        if self._db_engine is None:
            self._db_engine = self.construct_engine()
        return self._db_engine

    def get_next_primary_key(self, database_table):
        """
        Get the next free primary key for a table

        Parameters
        ----------
        database_table
            ORM model class with an ``id`` column

        Returns
        -------
        int
            One more than the largest ``id`` in the table, or 1 if the table is empty
        """
        with self.session() as session:
            pk = session.query(sqlalchemy.func.max(database_table.id)).scalar()
        return (pk or 0) + 1

    def create_new_current_workflow(
        self, workflow_type: WorkflowType, name: typing.Optional[str] = None
    ):
        """
        Make a workflow the current one, creating it (and its log directory) if needed

        Parameters
        ----------
        workflow_type: :class:`~montreal_forced_aligner.data.WorkflowType`
            Type of the workflow
        name: str, optional
            Name of the workflow, defaults to the name of ``workflow_type``
        """
        if not name:
            name = workflow_type.name
        self._current_workflow = name
        with self.session() as session:
            # Clear the flag on every workflow before marking the requested one
            session.query(CorpusWorkflow).update({"current": False})
            new_workflow = (
                session.query(CorpusWorkflow).filter(CorpusWorkflow.name == name).first()
            )
            if not new_workflow:
                new_workflow = CorpusWorkflow(
                    name=name,
                    workflow_type=workflow_type,
                    working_directory=os.path.join(self.output_directory, name),
                    current=True,
                )
                log_dir = os.path.join(new_workflow.working_directory, "log")
                os.makedirs(log_dir, exist_ok=True)
                session.add(new_workflow)
            else:
                new_workflow.current = True
            session.commit()

    def set_current_workflow(self, identifier):
        """
        Make an existing workflow the current one

        Parameters
        ----------
        identifier: str
            Name of the workflow; it must already exist in the database
        """
        with self.session() as session:
            session.query(CorpusWorkflow).update({CorpusWorkflow.current: False})
            wf = session.query(CorpusWorkflow).filter(CorpusWorkflow.name == identifier).first()
            wf.current = True
            self._current_workflow = identifier
            session.commit()

    @property
    def current_workflow(self) -> CorpusWorkflow:
        """Workflow currently flagged as current in the database, or None if there is none"""
        with self.session() as session:
            wf = (
                session.query(CorpusWorkflow)
                .filter(CorpusWorkflow.current == True)  # noqa
                .first()
            )
        return wf

    @property
    def db_path(self) -> Path:
        """Connection path for sqlite database"""
        return self.output_directory.joinpath(f"{self.identifier}.db")

    @property
    def db_string(self) -> str:
        """Connection string for the database"""
        if config.USE_POSTGRES:
            return f"postgresql+psycopg2://@/{self.identifier}?host={config.database_socket()}"
        return f"sqlite:///{self.db_path}"

    def construct_engine(self, **kwargs) -> sqlalchemy.engine.Engine:
        """
        Construct a database engine

        Parameters
        ----------
        read_only: bool, optional
            SQLite only: open the database file read-only, defaults to False. Ignored for
            PostgreSQL (it is never forwarded to :func:`sqlalchemy.create_engine`)
        **kwargs
            Other keyword arguments for :func:`sqlalchemy.create_engine`; ``pool_size``
            and ``max_overflow`` default to ``config.NUM_JOBS + 10``

        Returns
        -------
        :class:`~sqlalchemy.engine.Engine`
            SqlAlchemy engine
        """
        db_string = self.db_string
        # Always pop: create_engine does not accept a ``read_only`` argument
        read_only = kwargs.pop("read_only", False)
        if read_only and not config.USE_POSTGRES:
            db_string += "?mode=ro&nolock=1&uri=true"
        kwargs.setdefault("pool_size", config.NUM_JOBS + 10)
        kwargs.setdefault("max_overflow", config.NUM_JOBS + 10)
        return sqlalchemy.create_engine(
            db_string,
            **kwargs,
        )

    @property
    def session(self) -> sqlalchemy.orm.scoped_session:
        """
        Database session factory

        Created on first access and bound to :attr:`db_engine`; sessions do not expire
        objects on commit. Call it (``self.session()``) to get a session.

        Returns
        -------
        :class:`~sqlalchemy.orm.scoped_session`
            SqlAlchemy scoped session
        """
        if self._session is None:
            self._session = scoped_session(
                sessionmaker(bind=self.db_engine, expire_on_commit=False)
            )
        return self._session
class MfaWorker(metaclass=abc.ABCMeta):
    """
    Abstract class for MFA workers

    Attributes
    ----------
    dirty: bool
        Flag for whether an error was encountered in processing
    """

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dirty = False

    @classmethod
    def get_configuration_parameters(cls) -> typing.Dict[str, typing.Type]:
        """
        Get the types of parameters available to be configured

        Type hints of ``__init__`` are collected over the class's MRO. Optional/union
        hints are reduced to their first member (``Optional[int]`` and ``int | None``
        both become ``int``). Parameterized containers are reduced to their builtin
        (``List[str]`` becomes ``list``).

        Returns
        -------
        dict[str, Type]
            Dictionary of parameter names and their types
        """
        import types

        union_origins = (typing.Union,)
        # PEP 604 unions (``X | None``) have their own origin type on Python 3.10+
        if hasattr(types, "UnionType"):
            union_origins += (types.UnionType,)
        container_types = (dict, tuple, list, set)

        configuration_params = {}
        # cls is the first entry of its own MRO, so its __init__ is covered here; base
        # classes come later and overwrite hints for parameter names they share
        for c in cls.mro():
            try:
                hints = typing.get_type_hints(c.__init__)
            except AttributeError:
                continue
            for name, hint in hints.items():
                if typing.get_origin(hint) in union_origins:
                    hint = typing.get_args(hint)[0]
                configuration_params[name] = hint
        for name, hint in configuration_params.items():
            origin = typing.get_origin(hint)
            if origin in container_types:
                configuration_params[name] = origin
        return configuration_params

    @property
    def configuration(self) -> MetaDict:
        """Configuration parameters"""
        return {
            "dirty": self.dirty,
        }

    @property
    @abc.abstractmethod
    def working_directory(self) -> Path:
        """Current working directory"""
        ...

    @property
    def working_log_directory(self) -> Path:
        """Current working log directory (``log`` inside :attr:`working_directory`)"""
        return self.working_directory.joinpath("log")

    @property
    @abc.abstractmethod
    def data_directory(self) -> Path:
        """Data directory"""
        ...
class TopLevelMfaWorker(MfaWorker, TemporaryDirectoryMixin, metaclass=abc.ABCMeta):
    """
    Abstract mixin for top-level workers in MFA. This class holds properties about the larger workflow run.

    Keyword arguments are filtered through ``extract_relevant_parameters``. Keys that
    are not relevant to the worker are logged and skipped. Run-wide settings, such as
    the number of jobs and whether to clean, are read from
    :mod:`montreal_forced_aligner.config`.
    """

    # Fields where a ``null`` value in a configuration file is read as an empty list
    nullable_fields = [
        "punctuation",
        "compound_markers",
        "clitic_markers",
        "quote_markers",
        "word_break_markers",
    ]

    def __init__(
        self,
        **kwargs,
    ):
        kwargs, skipped = type(self).extract_relevant_parameters(kwargs)
        super().__init__(**kwargs)
        self.initialized = False
        self.start_time = time.time()
        self.setup_logger()
        if skipped:
            logger.warning(f"Skipped the following configuration keys: {comma_join(skipped)}")

    def cleanup_logger(self):
        """
        Ensure that loggers are cleaned up on delete

        File handlers of the ``mfa`` logger are closed, and every handler is detached.
        """
        mfa_logger = logging.getLogger("mfa")
        # Iterate over a copy because handlers are removed inside the loop
        for handler in mfa_logger.handlers[:]:
            if isinstance(handler, logging.FileHandler):
                handler.close()
            mfa_logger.removeHandler(handler)

    def setup(self) -> None:
        """
        Setup for worker

        Checks the previous run's configuration. It then initializes and inspects the
        database when the concrete worker provides those methods.
        """
        self.check_previous_run()
        if hasattr(self, "initialize_database"):
            self.initialize_database()
        if hasattr(self, "inspect_database"):
            self.inspect_database()

    @property
    def working_directory(self) -> Path:
        """Alias for a folder that contains worker information, separate from the data directory"""
        return self.output_directory.joinpath(self._current_workflow)

    @staticmethod
    def _parse_unknown_args(
        unknown_args: typing.Optional[typing.Iterable[str]],
        param_types: typing.Dict[str, typing.Any],
    ) -> MetaDict:
        """Collect ``--name value`` / ``--name=value`` / ``--name`` options for known parameters"""
        unknown_dict = {}
        if not unknown_args:
            return unknown_dict
        # Look-ahead indexing below needs a sequence, so accept any iterable
        unknown_args = list(unknown_args)
        for i, arg in enumerate(unknown_args):
            if not arg.startswith("--"):
                continue
            name, has_inline_value, inline_value = arg[2:].partition("=")
            if name not in param_types:
                continue
            if has_inline_value:
                value = inline_value
            elif i == len(unknown_args) - 1 or unknown_args[i + 1].startswith("--"):
                # Bare flag with no value
                value = True
            else:
                value = unknown_args[i + 1]
            unknown_dict[name] = value
        return unknown_dict

    @classmethod
    def parse_args(
        cls,
        args: typing.Optional[typing.Dict[str, typing.Any]],
        unknown_args: typing.Optional[typing.List[str]],
    ) -> MetaDict:
        """
        Class method for parsing configuration parameters from command line arguments

        Unknown arguments may be given as ``--name value``, ``--name=value``, or a bare
        ``--name`` flag (read as ``True``). For boolean parameters, a string value of
        ``false``, ``0``, ``no`` or ``off`` (case-insensitive) gives ``False``; any other
        string gives ``True``.

        Parameters
        ----------
        args: dict[str, Any]
            Parsed arguments
        unknown_args: list[str]
            Optional list of arguments that were not parsed

        Returns
        -------
        dict[str, Any]
            Dictionary of specified configuration parameters
        """
        from montreal_forced_aligner.data import Language

        param_types = cls.get_configuration_parameters()
        unknown_dict = cls._parse_unknown_args(unknown_args, param_types)
        params = {}
        for name, param_type in param_types.items():
            # Directory and path parameters are skipped, apart from the listed exceptions
            if (name.endswith("_directory") and name != "audio_directory") or (
                name.endswith("_path")
                and name not in {"rules_path", "phone_groups_path", "topology_path"}
            ):
                continue
            if args is not None and name in args and args[name] is not None:
                value = args[name]
                if param_type == Language:
                    params[name] = param_type[value]
                else:
                    params[name] = param_type(value)
            elif name in unknown_dict:
                value = unknown_dict[name]
                if param_type == Language:
                    params[name] = param_type[value]
                elif param_type == bool and not isinstance(value, bool):
                    # bool("false") would be True, so parse the text explicitly
                    params[name] = value.lower() not in {"false", "0", "no", "off"}
                else:
                    params[name] = param_type(value)
        return params

    @classmethod
    def parse_parameters(
        cls,
        config_path: typing.Optional[Path] = None,
        args: typing.Optional[typing.Dict[str, typing.Any]] = None,
        unknown_args: typing.Optional[typing.Iterable[str]] = None,
    ) -> MetaDict:
        """
        Parse configuration parameters from a config file and command line arguments

        Command line values take precedence over values from the configuration file.

        Parameters
        ----------
        config_path: :class:`~pathlib.Path`, optional
            Path to yaml configuration file
        args: dict[str, Any]
            Parsed arguments
        unknown_args: list[str]
            Optional list of arguments that were not parsed

        Returns
        -------
        dict[str, Any]
            Dictionary of specified configuration parameters
        """
        global_params = {}
        if config_path and os.path.exists(config_path):
            data = load_configuration(config_path)
            for key, value in data.items():
                if value is None and key in cls.nullable_fields:
                    value = []
                global_params[key] = value
        global_params.update(cls.parse_args(args, unknown_args))
        return global_params

    @property
    def worker_config_path(self) -> str:
        """Path to worker's configuration in the working directory"""
        return os.path.join(self.output_directory, f"{self.data_source_identifier}.yaml")

    def cleanup(self) -> None:
        """
        Clean up loggers and output final message for top-level workers

        Database connections are closed first. With ``config.FINAL_CLEAN``, the database
        is dropped (``dropdb`` for PostgreSQL) and the working directory is removed.
        The worker configuration is then saved and the logger's handlers are cleaned up.
        """
        try:
            if hasattr(self, "cleanup_connections"):
                self.cleanup_connections()
            if self.dirty:
                logger.error("There was an error in the run, please see the log.")
            else:
                logger.info(f"Done! Everything took {time.time() - self.start_time:.3f} seconds")
            if config.FINAL_CLEAN:
                logger.debug(
                    "Cleaning up temporary files, use the --no_final_clean flag to keep temporary files."
                )
                if hasattr(self, "delete_database"):
                    if config.USE_POSTGRES:
                        proc = subprocess.run(
                            [
                                "dropdb",
                                f"--host={config.database_socket()}",
                                "--if-exists",
                                "--force",
                                self.identifier,
                            ],
                            stderr=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            check=True,
                            encoding="utf-8",
                        )
                        logger.debug(f"Stdout: {proc.stdout}")
                        logger.debug(f"Stderr: {proc.stderr}")
                    else:
                        self.delete_database()
                self.clean_working_directory()
            self.save_worker_config()
            self.cleanup_logger()
        except (NameError, ValueError):  # already cleaned up
            pass

    def save_worker_config(self) -> None:
        """Export worker configuration to its working directory, if that directory still exists"""
        if not os.path.exists(self.output_directory):
            return
        with mfa_open(self.worker_config_path, "w") as f:
            yaml.dump(self.configuration, f, Dumper=MfaYamlDumper)

    def _validate_previous_configuration(self, conf: MetaDict) -> None:
        """
        Validate the current configuration against a previous configuration

        Marks the worker dirty when the previous run used a different MFA version
        (skipped when ``config.DEBUG`` is set).

        Parameters
        ----------
        conf: dict[str, Any]
            Previous run's configuration
        """
        from montreal_forced_aligner.utils import get_mfa_version

        self.dirty = False
        current_version = get_mfa_version()
        if not config.DEBUG and conf.get("version", current_version) != current_version:
            logger.debug(
                f"Previous run was on {conf['version']} version (new run: {current_version})"
            )
            self.dirty = True

    def check_previous_run(self) -> bool:
        """
        Check whether a previous run has any conflicting settings with the current run.

        Returns
        -------
        bool
            True if there is no previous configuration or it is compatible with the
            current run; False if it conflicts or could not be loaded
        """
        if not os.path.exists(self.worker_config_path):
            return True
        try:
            conf = load_configuration(self.worker_config_path)
            self._validate_previous_configuration(conf)
        except yaml.error.YAMLError:
            logger.warning("The previous run's configuration could not be loaded.")
            return False
        if not config.CLEAN and self.dirty:
            logger.warning(
                "The previous run had a different configuration than the current, which may cause issues."
                " Please see the log for details or use --clean flag if issues are encountered."
            )
        return not self.dirty

    @property
    def identifier(self) -> str:
        """Combined identifier of the data source and workflow"""
        return self.data_source_identifier

    @property
    def output_directory(self) -> Path:
        """Root temporary directory to store all of this worker's files"""
        return config.TEMPORARY_DIRECTORY.joinpath(self.identifier)

    @property
    def log_file(self) -> Path:
        """Path to the worker's log file"""
        return self.output_directory.joinpath(f"{self.data_source_identifier}.log")

    @staticmethod
    def _release_tuple(version: str) -> typing.Tuple[int, ...]:
        """Leading numeric release components of a version, e.g. ``"3.1.0a2"`` -> ``(3, 1, 0)``"""
        match = re.match(r"\d+(?:\.\d+)*", version)
        if match is None:
            return ()
        return tuple(int(part) for part in match.group(0).split("."))

    def setup_logger(self) -> None:
        """
        Construct a logger for a command line run

        Configures the ``mfa`` logger to write to :attr:`log_file` and logs diagnostic
        information about the run. When ``config.VERBOSE`` is set, GitHub is queried
        for the latest release so that an update can be suggested.
        """
        from montreal_forced_aligner.helper import configure_logger
        from montreal_forced_aligner.utils import get_mfa_version

        current_version = get_mfa_version()
        # Only records whether the previous run used a different MFA version, for the
        # debug message at the end; nothing is removed here
        clean = False
        if os.path.exists(self.worker_config_path):
            conf = load_configuration(self.worker_config_path)
            if conf.get("version", current_version) != current_version:
                clean = True
        os.makedirs(self.output_directory, exist_ok=True)
        configure_logger("mfa", log_file=self.log_file)
        logger = logging.getLogger("mfa")
        if config.VERBOSE:
            try:
                response = requests.get(
                    "https://api.github.com/repos/MontrealCorpusTools/Montreal-Forced-Aligner/releases/latest",
                    # Without a timeout an unresponsive network could block the run indefinitely
                    timeout=5,
                )
                latest_version = response.json()["tag_name"].replace("v", "")
                current_key = self._release_tuple(current_version)
                latest_key = self._release_tuple(latest_version)
                # Compare numerically: as strings "3.10.0" < "3.9.0"
                if current_key and latest_key and current_key < latest_key:
                    logger.debug(
                        f"You are currently running an older version of MFA ({current_version}) than the latest available ({latest_version}). "
                        f"To update, please run mfa_update."
                    )
            except Exception:
                # The update check is best-effort and must never break a run
                pass
        if re.search(r"\d+\.\d+\.\d+a", current_version) is not None:
            logger.debug(
                "Please be aware that you are running an alpha version of MFA. If you would like to install a more "
                "stable version, please visit https://montreal-forced-aligner.readthedocs.io/en/latest/installation.html#installing-older-versions-of-mfa",
            )
        logger.debug(f"Beginning run for {self.data_source_identifier}")
        logger.debug(f'Using "{config.CURRENT_PROFILE_NAME}" profile')
        if config.USE_MP:
            logger.debug(f"Using multiprocessing with {config.NUM_JOBS}")
        else:
            logger.debug(f"NOT using multiprocessing with {config.NUM_JOBS}")
        logger.debug(f"Set up logger for MFA version: {current_version}")
        if clean or config.CLEAN:
            logger.debug("Cleaned previous run")
class ExporterMixin(metaclass=abc.ABCMeta):
    """
    Abstract mixin class for exporting any kind of file

    Parameters
    ----------
    overwrite: bool
        Whether an existing file at the destination may be replaced, defaults to False
    """

    def __init__(self, overwrite: bool = False, **kwargs):
        # The flag is stored before delegating the remaining keyword arguments up the MRO
        self.overwrite = overwrite
        super().__init__(**kwargs)
class ModelExporterMixin(ExporterMixin, metaclass=abc.ABCMeta):
    """
    Abstract mixin class for exporting MFA models

    Concrete subclasses supply :attr:`meta` and :meth:`export_model`.
    """

    @property
    @abc.abstractmethod
    def meta(self) -> MetaDict:
        """Training configuration parameters"""

    @abc.abstractmethod
    def export_model(self, output_model_path: Path) -> None:
        """
        Abstract method to export an MFA model

        Parameters
        ----------
        output_model_path: :class:`~pathlib.Path`
            Destination path of the exported model
        """
class FileExporterMixin(ExporterMixin, metaclass=abc.ABCMeta):
    """
    Abstract mixin class for exporting TextGrid and text files

    Concrete subclasses supply :meth:`export_files`.
    """

    @abc.abstractmethod
    def export_files(self, output_directory: str) -> None:
        """
        Export files to an output directory

        Parameters
        ----------
        output_directory: str
            Destination directory for the exported files
        """
class TrainerMixin(ModelExporterMixin):
    """
    Abstract mixin class for MFA trainers

    Parameters
    ----------
    num_iterations: int
        Number of training iterations, defaults to 40
    model_version: str, optional
        Override for model version

    Attributes
    ----------
    iteration: int
        Current iteration, starting at 0
    """

    def __init__(
        self, num_iterations: int = 40, model_version: typing.Optional[str] = None, **kwargs
    ):
        super().__init__(**kwargs)
        self.iteration: int = 0
        self.num_iterations = num_iterations
        self.model_version = model_version

    @abc.abstractmethod
    def initialize_training(self) -> None:
        """Initialize training"""
        ...

    @abc.abstractmethod
    def train(self) -> None:
        """Perform training"""
        ...

    @abc.abstractmethod
    def train_iteration(self) -> None:
        """Run one training iteration"""
        ...

    @abc.abstractmethod
    def finalize_training(self) -> None:
        """Finalize training"""
        ...
class AdapterMixin(ModelExporterMixin):
    """
    Abstract class for MFA model adaptation

    Concrete subclasses supply :meth:`adapt` in addition to the model export interface.
    """

    @abc.abstractmethod
    def adapt(self) -> None:
        """Perform adaptation"""
class MfaModel(abc.ABC):
    """
    Abstract class for MFA models

    Attributes
    ----------
    extensions: list[str]
        Declared by subclasses; presumably the file extensions of the model type
    model_type: str
        Name of the model type, used as the subdirectory for pretrained models
    """

    extensions: typing.List[str]
    model_type = "base_model"

    @classmethod
    def pretrained_directory(cls) -> Path:
        """Directory that pretrained models are saved in (created if it does not exist)"""
        from .config import get_temporary_directory

        directory = get_temporary_directory() / "pretrained_models" / cls.model_type
        directory.mkdir(parents=True, exist_ok=True)
        return directory

    @classmethod
    def get_available_models(cls) -> typing.List[str]:
        """
        Get a list of available models for a given model type

        Returns
        -------
        list[str]
            Stems of the files in the pretrained directory with a valid extension
        """
        directory = cls.pretrained_directory()
        if not directory.exists():
            return []
        return [entry.stem for entry in directory.iterdir() if cls.valid_extension(entry)]

    @classmethod
    def get_pretrained_path(cls, name: str, enforce_existence: bool = True) -> Path:
        """
        Generate a path to a pretrained model based on its name and model type

        Parameters
        ----------
        name: str
            Name of model
        enforce_existence: bool
            Flag to return None if the path doesn't exist, defaults to True

        Returns
        -------
        Path to model, as produced by :meth:`generate_path`
        """
        root = cls.pretrained_directory()
        return cls.generate_path(root, name, enforce_existence)

    @classmethod
    @abc.abstractmethod
    def valid_extension(cls, filename: Path) -> bool:
        """Check whether a file has a valid extension"""

    @classmethod
    @abc.abstractmethod
    def generate_path(
        cls, root: Path, name: str, enforce_existence: bool = True
    ) -> typing.Optional[Path]:
        """Generate a path from a root directory"""

    @abc.abstractmethod
    def pretty_print(self) -> None:
        """Print the model's meta data"""

    @property
    @abc.abstractmethod
    def meta(self) -> MetaDict:
        """Metadata for the model"""
class PhoneRemapperMixin(metaclass=abc.ABCMeta):
    """
    Abstract mixin class for remapping phones

    Parameters
    ----------
    phone_mapping_path: :class:`~pathlib.Path`
        Path to a YAML file mapping each key phone to a phone or a list of phones
    """

    def __init__(
        self,
        phone_mapping_path: Path,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.phone_mapping_path = phone_mapping_path
        self.phone_remapping = {}

    def load_mapping(self) -> None:
        """
        Load :attr:`phone_mapping_path` into :attr:`phone_remapping`

        Scalar values are wrapped in single-element lists so every key maps to a list.
        An empty file results in an empty mapping.
        """
        with mfa_open(self.phone_mapping_path, "r") as f:
            # safe_load: the mapping only needs plain YAML scalars and lists, while the
            # full yaml.Loader can construct arbitrary Python objects from a user-supplied file
            mapping = yaml.safe_load(f) or {}
        self.phone_remapping = {
            key: values if isinstance(values, list) else [values]
            for key, values in mapping.items()
        }