#!/usr/bin/env python
"""Implement a mock of the sbatch command of slurm.
The implementation mimics the sbatch API, while using subprocess to
start the processes on the host machine without slurm daemons.
Double fork is used to detach the subprocess, allowing the sbatch
mock to return while the created process is still running.
"""
import argparse
import os
import subprocess as sp
import time
from enum import Enum
from typing import Dict, List, Tuple
import sys
import logging
from mockslurm.process_db import (
JobReason,
JobState,
append_job,
find_db_file,
get_db,
get_db_file_handle,
update_db_value,
)
class DependencyActions(Enum):
OK = 1
WAIT = 2
NEVER = 3
def parse_dependency(job_dependency: str) -> Tuple[bool, List[Dict[str, List[int]]]]:
"""Parse a job's dependency specification string in sbatch format
Parameters
----------
job_dependency : str
sbatch job's dependency string. Examples: afterok:20:21,afterany:23
Returns
-------
Tuple[bool, List[Dict[str, List[int]]]]
The tuple's 1st element indicates whether all dependencies must be satisfied for the job to start (True, "," separator in the slurm string),
or whether a single satisfied dependency is enough (False, "?" separator in the slurm string). When a single dependency is present: True.
Each dictionary in the list has a single key, the dependency type ("afterok", "afterany", etc.),
and the corresponding value is the list of job IDs for that dependency type.
Examples
--------
>>> parse_dependency("afterok:20:21?afterany:23")
(False, [{'afterok': [20, 21]}, {'afterany': [23]}])
"""
if "," in job_dependency and "?" in job_dependency:
raise ValueError('Multiple separators "?" and "," in job dependency not supported.')
separator = "?" if "?" in job_dependency else "," # default is "," if only 1 dep
dependencies = []
while job_dependency:
if job_dependency[0] == separator:
job_dependency = job_dependency[1:]
dep_str = job_dependency.split(separator, maxsplit=1)[0]
job_dependency = job_dependency[len(dep_str) :]
split = dep_str.split(":")
dependencies.append({split[0]: [int(v) for v in split[1:]]})
return (separator == ",", dependencies)
def check_dependency(job_dependency: Tuple[bool, List[Dict[str, List[int]]]]) -> DependencyActions:
"""Check if a job's dependencies are satisfied
Parameters
----------
job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
Job dependencies. Each dictionary in the list corresponds to a dependency specification.
See `mockslurm.mock_sbatch.parse_dependency`
Warning
-------
This function only supports the "afterok" dependency specification, anything else causes
a `NotImplementedError` to be raised.
Raises
------
NotImplementedError
If the dependency type is not "afterok"
Returns
-------
DependencyActions
The action to follow based on the dependency evaluation, for instance wait or start the job.
"""
if not job_dependency[1]:
return DependencyActions.OK
deps_value = []
with get_db_file_handle(find_db_file()) as db_file:
db = get_db(db_file)
for dep in job_dependency[1]:
internal_deps_ok = True
for dep_type, job_idx in dep.items():
if dep_type != "afterok":
raise NotImplementedError("Dependency type {} is not implemented.".format(dep_type))
dep_state = db[job_idx if isinstance(job_idx, list) else [job_idx]]["STATE"]
if any(dep_state == JobState.FAILED) or any(dep_state == JobState.CANCELLED):
return DependencyActions.NEVER
internal_deps_ok &= all(dep_state == JobState.COMPLETED)
deps_value.append(internal_deps_ok)
combined_dep = all(deps_value) if job_dependency[0] else any(deps_value)
return DependencyActions.OK if combined_dep else DependencyActions.WAIT
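# Illustrative use together with parse_dependency (job IDs 20 and 21 are hypothetical DB entries):
#   check_dependency(parse_dependency("afterok:20:21"))
# returns DependencyActions.OK once both jobs are COMPLETED, WAIT while they are still pending or running,
# and NEVER as soon as one of them is FAILED or CANCELLED.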
def launch_job(
cmd: str, stdout: str, stderr: str, job_idx: int, job_dependency: Tuple[bool, List[Dict[str, List[int]]]]
):
"""Runs `cmd` as a detached subprocess.
The return code of `cmd` is retrieved by the detached subprocess and updated in the DB.
Parameters
----------
cmd : str
Command to run, similar to the content of the --wrap= argument of sbatch
stdout : str
Path to a file where the output of `cmd` will be piped
stderr : str
Path to a file where the errors of `cmd` will be piped
job_idx : int
Index of the job in the DB
job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
Job dependencies. Each dictionary in the list corresponds to a dependency specification.
See `mockslurm.mock_sbatch.parse_dependency`
"""
logging.debug(
"lauch_job with args cmd: {}, stdout: {}, stderr: {}, job_idx: {}, job_dependency: {}".format(
cmd, stdout, stderr, job_idx, job_dependency
)
)
# Wait for dependencies to be ready
dependency_check = check_dependency(job_dependency)
if dependency_check == DependencyActions.WAIT:
logging.debug(
"Job {} dependency {} 1st check is WAIT, updating db Reason DEPENDENCY".format(job_idx, job_dependency)
)
with get_db_file_handle(find_db_file()) as db_file:
db = get_db(db_file)
update_db_value(db_file, job_idx, key="REASON", value=JobReason.Dependency)
while dependency_check == DependencyActions.WAIT:
time.sleep(0.25)
logging.debug("Job {} dependency {} check remains WAIT, sleeping".format(job_idx, job_dependency))
dependency_check = check_dependency(job_dependency)
# If not ok: do not start job and mark its state as FAILED
if dependency_check != DependencyActions.OK:
logging.debug(
"Job {} dependency check is {}, updating DB with REASON DependencyNeverSatisfied and STATE FAILED".format(
job_idx, dependency_check
)
)
with get_db_file_handle(find_db_file()) as db_file:
db = get_db(db_file)
update_db_value(db_file, job_idx, key="REASON", value=JobReason.DependencyNeverSatisfied)
update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
else:
logging.debug("Job {} dependency check is OK, checking job STATE".format(job_idx))
# lock the DB here, so STATE can not change before we start the process
with get_db_file_handle(find_db_file()) as db_file:
db = get_db(db_file)
if db[job_idx]["STATE"] == JobState.PENDING:
logging.debug("Job {} STATE is still pending, starting job".format(job_idx))
try:
stdout_f = open(stdout, "a")
stderr_f = sp.STDOUT if stderr == stdout else open(stderr, "a")
# Cannot use shlex here to split the command: that would split bash constructs at each
# space, so we pass a single string with shell=True
p = sp.Popen(cmd, stdout=stdout_f, stderr=stderr_f, start_new_session=True, shell=True)
logging.debug(
"Job {} started with PID {}, cmd {}, stdout {}, stderr {}".format(
job_idx, p.pid, cmd, stdout_f, stderr_f
)
)
update_db_value(db_file, job_idx, key="PID", value=p.pid)
update_db_value(db_file, job_idx, key="STATE", value=JobState.RUNNING)
update_db_value(db_file, job_idx, key="REASON", value=JobReason.NOREASON)
logging.debug(
"Job {} updated db with PID {}, STATE RUNNING, REASON NOREASON".format(job_idx, p.pid)
)
except Exception:
logging.debug("Job {} failed to start. Updating DB with STATE FAILED and REASON JobLaunchFailure".format(job_idx))
update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
update_db_value(db_file, job_idx, key="REASON", value=JobReason.JobLaunchFailure)
logging.debug("Job {} starting error".format(job_idx), exc_info=True)
raise
else:
logging.debug(
"Job {} STATE is {}. Do not start start job and exit".format(job_idx, db[job_idx]["STATE"])
)
# Process is not pending anymore, it must have been killed already, so we won't start it
return
logging.debug("Job {} is running, waiting its completion".format(job_idx))
# wait for process to be done
exit_code = p.wait()
logging.debug("Job {} completed with exit code {}".format(job_idx, exit_code))
# closing stdout and stderr file of job
stdout_f.close()
if stderr != stdout:
stderr_f.close()
logging.debug("Job {} closed stdout {} and stderr {}".format(job_idx, stdout, stderr))
# Update job state in db
with get_db_file_handle(find_db_file()) as db_file:
job_state = JobState.COMPLETED if exit_code == 0 else JobState.FAILED
job_reason = JobReason.NOREASON if exit_code == 0 else JobReason.NonZeroExitCode
update_db_value(db_file, job_idx, key="EXIT_CODE", value=exit_code)
update_db_value(db_file, job_idx, key="STATE", value=job_state)
update_db_value(db_file, job_idx, key="REASON", value=job_reason)
logging.debug(
"Job {} Updated db with EXIT_CODE {}, STATE {}, REASON {}".format(
job_idx, exit_code, JobState(job_state).name, JobReason(job_reason).name
)
)
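# Illustrative call (the paths and job index are hypothetical, and job 1 is assumed to already exist in the
# DB in state PENDING): this blocks until "sleep 2" exits, then records its exit code, STATE and REASON.
#   launch_job("sleep 2", "/tmp/job_1.out", "/tmp/job_1.out", 1, parse_dependency(""))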
def main():
parser = argparse.ArgumentParser(
description="Slurm sbtach mock.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument("--account", "-A", type=str, dest="ACCOUNT", help="user account", required=False)
parser.add_argument(
"--dependency",
"-d",
type=str,
default="",
dest="dependency",
help="Defer the start of this job until the specified dependencies have been satisfied",
)
parser.add_argument(
"--error", "-e", type=str, dest="stderr", help="error file of slurm job", default="slurm-%j.out"
)
parser.add_argument("--job-name", "-J", type=str, dest="NAME", help="job name", default="wrap")
parser.add_argument(
"--nodelist",
"-w",
type=str,
dest="NODELIST",
help="Request a specific list of hosts",
nargs="*",
)
parser.add_argument(
"--output", "-o", type=str, dest="stdout", help="output file of slurm job", default="slurm-%j.out"
)
parser.add_argument("--partition", "-p", type=str, dest="PARTITION", help="job partition", required=False)
parser.add_argument(
"--parsable",
action="store_true",
dest="parsable",
help="Outputs only the job id number. Errors will still be displayed",
)
parser.add_argument("--reservation", type=str, dest="RESERVATION", help="job reservation", required=False)
parser.add_argument(
"--wrap",
type=str,
dest="CMD",
help="Command to be executed",
required=True,
)
parser.add_argument(
"--mock_sbatch_debug",
action="store_true",
dest="debug_mode",
help="If provided, logs all actions of the sbatch mock in a mock_sbatch.log file",
default=False,
)
args, _ = parser.parse_known_args()
defined_args = {arg: value for arg, value in vars(args).items() if value is not None}
defined_args.pop("stdout")
defined_args.pop("stderr")
defined_args.pop("parsable")
defined_args.pop("dependency")
defined_args.pop("debug_mode")
# TODO: raise error if arguments to sbatch aren't valid (reservation, partition, nodelist ? etc.)
if args.debug_mode:
logging.basicConfig(
filename="mock_sbatch.log",
level=logging.DEBUG,
format="%(asctime)s %(levelname)s mock_sbatch %(pathname)s:%(lineno)s:%(funcName)s %(message)s",
)
logging.debug(
"\n".join(
["mock_sbatch called with args:"]
+ ["--{}: {}".format(arg, arg_value) for arg, arg_value in vars(args).items()]
)
)
with get_db_file_handle(find_db_file()) as db_file:
job_idx = append_job(db_file, **defined_args)
logging.debug("Appended job {} in DB with args {}".format(job_idx, defined_args))
# if parsable is set: print jobID
if args.parsable:
print(job_idx, end="", flush=True)
logging.debug("Printing job ID {} to stdout".format(job_idx))
# In order to create a process to run the sbatch command and exit sbatch without killing the child process or
# making it a zombie, we need to use the double fork technique used to create daemons:
# see https://stackoverflow.com/questions/473620/how-do-you-create-a-daemon-in-python
# https://stackoverflow.com/questions/881388/what-is-the-reason-for-performing-a-doube-fork-when-creating-a-daemon
# Double forking is used to create a new process B (attached to the process group of A, its parent), then
# create a new process C and detach it from A. The new process can then run and be cleaned up by process 1
# instead of A.
# There are a number of other things to consider to completely detach a process, such as closing all file handles,
# here we only do the minimum for the mock to work.
# See https://pagure.io/python-daemon/blob/main/f/src/daemon/daemon.py for a reference implementation of a daemon
# in Python (an older project; the associated PEP was never accepted due to lack of interest)
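# Stripped-down sketch of the double-fork pattern used below (run_the_job is a placeholder):
#   if os.fork() == 0:        # first child
#       os.setsid()           # new session: detach from the parent's controlling terminal
#       if os.fork() == 0:    # grandchild: orphaned when the first child exits, reparented to PID 1
#           run_the_job()
#           os._exit(0)
#       os._exit(0)           # first child exits immediately
#   else:
#       os.wait()             # parent reaps the first child and can return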
logging.debug("Job {} Parent process will fork child".format(job_idx))
# fork a first time
pid_1st_fork = os.fork()
if pid_1st_fork == 0: # we are in the child process
logging.debug("Job {} Child process, creating new process session and becoming group leader".format(job_idx))
# create new process session for the child
# the 2nd fork process will be in this session, detached from the initial process
os.setsid()
logging.debug("Job {} Child process will fork grandchild".format(job_idx))
# second fork
pid_2nd_fork = os.fork()
if pid_2nd_fork == 0: # we are in the grandchild: run the sbatch job, then update db
logging.debug("Job {} Grandchild process re-derecting stdin, stdout and stderr".format(job_idx))
# redirect stdin, stdout and stderr to /dev/null (so we separate from the child's file handles)
# (the child's file handles are the same as the parent's file handles...)
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
logging.debug("Job {} Grandchild process calls launch_job".format(job_idx))
launch_job(
args.CMD,
args.stdout.replace("%j", str(job_idx).strip()),
args.stderr.replace("%j", str(job_idx).strip()),
job_idx,
parse_dependency(args.dependency),
)
# exit without clean up.
# it is REQUIRED to have the grandchild exit in this way if it is tested with pytest, otherwise pytest
# will run the program several times see https://github.com/pytest-dev/pytest/issues/12028
# Side effect: the stdio buffers will not be flushed (detached process can't output to stdio anyway)
os._exit(0)
else:
# exit the child without cleaning up file handles or flushing buffers, because those are shared with the parent
logging.debug("Job {} Child process exiting without cleaning up.".format(job_idx))
os._exit(0)
else:
os.wait() # this waits for the child (not the grandchild!) to finish, so it can be cleaned up
logging.debug("Job {} Parent process exiting after waiting on child process.".format(job_idx))
if __name__ == "__main__":
main()