"""Implements the access to a "slurm database", stored in a HDF5 file.
HDF5 files are locked during access, which provides a convenient method to
not have several process modifying the same database easily
"""
import copy
import datetime
import getpass
import os
import time
from enum import IntEnum
from pathlib import Path
from typing import Any, Dict, List
import h5py
import numpy as np
class JobState(IntEnum):
COMPLETED = 1
RUNNING = 2
PENDING = 3
CANCELLED = 4
FAILED = 5
class JobReason(IntEnum):
    NOREASON = 1  # sometimes there is no reason: the job is running, for instance
WaitingForScheduling = 2
Dependency = 3
DependencyNeverSatisfied = 4
NonZeroExitCode = 5
JobLaunchFailure = 6
_DB_DTYPE = np.dtype(
[
("PID", np.int64),
("NAME", np.dtype("S128")),
("USER", np.dtype("S128")),
("ACCOUNT", np.dtype("S128")),
("PARTITION", np.dtype("S128")),
("RESERVATION", np.dtype("S128")),
("NODELIST", np.dtype("S16384")),
("TIME", np.int64),
("START_TIME", np.float64),
("CMD", np.dtype("S16384")),
("STATE", np.int64),
("REASON", np.int64),
("EXIT_CODE", np.int64),
]
)
DB_DEFAULTS = {
"PID": -1,
"NAME": "wrap",
"USER": getpass.getuser(),
"ACCOUNT": getpass.getuser(),
"PARTITION": "",
"RESERVATION": "",
"NODELIST": "mocknode1",
"TIME": 0,
"START_TIME": datetime.datetime.now().timestamp(),
"CMD": "",
"STATE": JobState.PENDING,
"REASON": JobReason.WaitingForScheduling,
"EXIT_CODE": np.iinfo(np.int16).max,
}
def find_db_file() -> Path:
"""Return a path to the location of the mock database hdf5 file.
The database file location is searched among various locations typically available.
Returns
-------
Path
Path to the mock database hdf5 file.
Raises
------
FileNotFoundError
If a suitable location for the database file could not be found.
"""
possible_locations = [Path("/tmp"), Path("/var/tmp"), Path(os.environ["HOME"]) / ".local", Path(".")]
mock_slurm_db = Path("mock_slurm_db.h5")
for p in possible_locations:
        # Retrieve the owner permission digit for the directory: if it is 7, we can read,
        # write and execute, so we can open the folder and write inside it.
if p.exists() and int(oct(p.stat().st_mode)[-3]) == 7:
return p / mock_slurm_db
raise FileNotFoundError("Could not find a location to write mock slurm DB.")
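# Illustrative sketch of the lookup above (assuming /tmp exists with rwx owner permissions):
# find_db_file() would return Path("/tmp/mock_slurm_db.h5"), falling back to /var/tmp,
# $HOME/.local and finally the current directory otherwise.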
def open_file_retry_on_locked(file: Path, mode: str = "a", nb_retries: int = 40, wait_s: float = 0.01) -> h5py.File:
"""Open `file` as an HDF5 file in `mode`.
    Since HDF5 files are locked while another process accesses them, this function retries
    opening the file several times before giving up.
Parameters
----------
    file : Path
        Path to the file to open
    mode : str, optional
        Mode in which to open the HDF5 file, by default "a"
    nb_retries : int, optional
        Number of times to try to open the file, by default 40
    wait_s : float, optional
        Amount of time to wait between 2 attempts to open the file, in seconds, by default 0.01
Returns
-------
h5py.File
Open file handle to `file`.
"""
    # Failing to open a file on fefs doesn't necessarily mean we won't succeed next time!
for _ in range(nb_retries - 1):
try:
f = h5py.File(file, mode)
return f
except BlockingIOError:
time.sleep(wait_s)
else:
        # try one last time; let the h5py error propagate if it fails again
return h5py.File(file, mode)
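# Illustrative usage (a sketch, assuming the database file already exists): with the defaults,
# a locked file is retried every `wait_s` seconds (10 ms) for up to `nb_retries` attempts, and
# the last attempt lets the h5py error propagate.
#
# with open_file_retry_on_locked(find_db_file(), mode="r") as handle:
#     print(handle.keys())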
def clear_db():
"""Deletes the database file."""
db_file = find_db_file()
if db_file.exists():
print("Deleting db at {}".format(db_file))
db_file.unlink()
else:
print("No file to delete.")
def get_db_file_handle(db_file: Path) -> h5py.File:
"""Open the database HDF5 file in append mode.
On success, the returned h5py.File contains a dataset with the expected dtypes of the database.
Parameters
----------
db_file : Path
Path to the database file to open
Returns
-------
h5py.File
File handle to the database
"""
if not db_file.exists():
db = np.empty(
dtype=_DB_DTYPE,
shape=(0,),
)
f = open_file_retry_on_locked(db_file)
f.create_dataset("SLURM_DB", data=db, maxshape=(None,))
return f
else:
return open_file_retry_on_locked(db_file)
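# Sketch of a typical call (illustrative only): the first call creates an empty, resizable
# "SLURM_DB" dataset with the dtype declared in _DB_DTYPE; later calls simply reopen the
# existing file in append mode.
#
# with get_db_file_handle(find_db_file()) as db_handle:
#     print(get_db(db_handle).shape)  # (0,) for a freshly created database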
def get_db(db_file: h5py.File) -> h5py.Dataset:
"""Get the Database as a dataset in the HDF5 file
Parameters
----------
db_file : h5py.File
Opened database file handle
Returns
-------
h5py.Dataset
Dataset storing the database
"""
return db_file["SLURM_DB"]
def update_with_default_value(db_dict: Dict) -> Dict:
"""Update `db_dict` missing fields with the default value in the database
Parameters
----------
db_dict : Dict
        dict containing a database row's data, possibly missing some columns that will be filled with default values
Returns
-------
Dict
        dict with all fields expected by the database, taking values from `db_dict` when present and default values otherwise
"""
default_dict = copy.deepcopy(DB_DEFAULTS)
default_dict.update(db_dict)
return default_dict
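# Example of the merge performed above (illustrative): only the explicitly provided fields
# override DB_DEFAULTS.
#
# update_with_default_value({"PID": 1234, "NAME": "train"})
# # -> {"PID": 1234, "NAME": "train", "USER": <current user>, ...,
# #     "STATE": JobState.PENDING, "REASON": JobReason.WaitingForScheduling, ...}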
def append_job(db_file: h5py.File, **kwargs) -> int:
"""Append a job to the database
Parameters
----------
    db_file : h5py.File
        Opened file handle to the database
    **kwargs
        Field values for the new row; any missing field is filled with its `DB_DEFAULTS` value
Returns
-------
int
Index of the job that was appended
"""
dataset = get_db(db_file)
dataset.resize(dataset.shape[0] + 1, axis=0)
job_data = np.empty(dtype=dataset.dtype, shape=(1,))
for k, v in update_with_default_value(kwargs).items():
job_data[k] = v
dataset[-1] = job_data
return dataset.shape[0] - 1
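# Illustrative call (a sketch, not executed at import time): append a pending job and keep its
# row index for later updates. Any field omitted from the keyword arguments is filled from
# DB_DEFAULTS.
#
# with get_db_file_handle(find_db_file()) as db_handle:
#     job_index = append_job(db_handle, PID=1234, NAME="train", CMD="python train.py")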
def update_db_value(db_file: h5py.File, index: int, key: str, value: Any):
"""Update the `value` of `key` in the database, at `index`
Parameters
----------
db_file : h5py.File
Opened file handle to the database HDF5 file
index : int
Index of the row to update
key : str
Field to update
value : Any
New value for the field
"""
dataset = get_db(db_file)
update_value = dataset[index]
update_value[key] = value
dataset[index] = update_value
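# Illustrative follow-up to the append_job sketch above: mark the job as running, then
# completed with a zero exit code.
#
# update_db_value(db_handle, job_index, "STATE", JobState.RUNNING)
# update_db_value(db_handle, job_index, "STATE", JobState.COMPLETED)
# update_db_value(db_handle, job_index, "EXIT_CODE", 0)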
def get_filtered_DB_mask(db_file: h5py.File, fields_values: Dict[str, str | List[str]]) -> np.ndarray:
"""Get a mask selecting the DB rows where the field values are equal to `fields_values` values.
Parameters
----------
db_file : h5py.File
Opened file handle to the database HDF5 file
fields_values : Dict[str, str | List[str]]
        Map from field names to allowed field values. Rows where the field value is not equal to
        one of the allowed values are not selected.
        Key: field name, e.g. "NAME", "USER"
        Value: allowed value(s), e.g. "Robert" or ["Robert", "Roberta"]
Returns
-------
np.ndarray
Index mask array, True where the row's fields are equal to the `field_values`.
"""
db = get_db(db_file)
total_mask = np.ones(shape=(db.shape[0],), dtype=bool)
for field, values in fields_values.items():
mask = np.zeros(shape=(db.shape[0],), dtype=bool)
value_list = values if isinstance(values, list) else [values]
for v in value_list:
mask |= db[field] == v
total_mask &= mask
return total_mask
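if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only, run solely when this module is executed
    # directly): create or open the database, append a job, update it, then select it back
    # with a mask. NAME is stored as fixed-size bytes, so the filter uses a bytes value.
    with get_db_file_handle(find_db_file()) as demo_handle:
        demo_index = append_job(demo_handle, PID=1234, NAME="demo", CMD="echo hello")
        update_db_value(demo_handle, demo_index, "STATE", JobState.COMPLETED)
        update_db_value(demo_handle, demo_index, "EXIT_CODE", 0)
        demo_mask = get_filtered_DB_mask(demo_handle, {"NAME": b"demo"})
        print(get_db(demo_handle)[demo_mask])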