Source code for mockslurm.mock_sbatch

#!/usr/bin/env python
"""Implement a mock of the sbatch command of slurm.

The implementation mimics the sbatch API, using subprocess to start
processes on the host machine without requiring any Slurm daemons.

Double fork is used to detach the subprocess, allowing the sbatch
mock to return while the created process is still running.
"""

import argparse
import os
import subprocess as sp
import time
from enum import Enum
from typing import Dict, List, Tuple
import sys
import logging

from mockslurm.process_db import (
    JobReason,
    JobState,
    append_job,
    find_db_file,
    get_db,
    get_db_file_handle,
    update_db_value,
)


class DependencyActions(Enum):
    """Possible outcomes of a job dependency check."""

    OK = 1
    WAIT = 2
    NEVER = 3


def parse_dependency(job_dependency: str) -> Tuple[bool, List[Dict[str, List[int]]]]:
    """Parse a job's dependency specification string in sbatch format.

    Parameters
    ----------
    job_dependency : str
        sbatch job dependency string. Example: "afterok:20:21,afterany:23"

    Returns
    -------
    Tuple[bool, List[Dict[str, List[int]]]]
        The tuple's 1st element is whether all dependency groups must be satisfied for the
        job to start (True, "," separator in the slurm string) or whether a single group
        is enough (False, "?" separator). When only a single dependency group is present,
        the 1st element is True.
        Each element of the list is a dictionary with a single key, the dependency type
        ("afterok", "afterany", etc.), mapping to the job IDs of that dependency.

    Example
    -------
    >>> parse_dependency("afterok:20:21?afterany:23")
    (False, [{'afterok': [20, 21]}, {'afterany': [23]}])
    """
    if "," in job_dependency and "?" in job_dependency:
        raise ValueError('Multiple separators "?", "," in job dependency are not supported.')
    separator = "?" if "?" in job_dependency else ","  # default is "," if only 1 dep
    dependencies = []
    while job_dependency:
        if job_dependency[0] == separator:
            job_dependency = job_dependency[1:]
        dep_str = job_dependency.split(separator, maxsplit=1)[0]
        job_dependency = job_dependency[len(dep_str):]
        split = dep_str.split(":")
        dependencies.append({split[0]: [int(v) for v in split[1:]]})
    return (separator == ",", dependencies)
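

# Additional illustration (comment only, values are arbitrary): with the "," separator
# every dependency group must be satisfied, and a single group defaults to the same
# "all" behaviour:
#
#   parse_dependency("afterok:20:21,afterany:23")
#   -> (True, [{'afterok': [20, 21]}, {'afterany': [23]}])
#   parse_dependency("afterok:7")
#   -> (True, [{'afterok': [7]}])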


def check_dependency(job_dependency: Tuple[bool, List[Dict[str, List[int]]]]) -> DependencyActions:
    """Check if a job's dependencies are satisfied.

    Parameters
    ----------
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies, as returned by `mockslurm.mock_sbatch.parse_dependency`: the 1st
        element tells whether all dependency groups must be satisfied, the 2nd is the
        list of dependency groups.

    Warning
    -------
    This function only supports the "afterok" dependency specification, anything else
    causes a `NotImplementedError` to be raised.

    Raises
    ------
    NotImplementedError
        If the dependency type is not "afterok"

    Returns
    -------
    DependencyActions
        The action to follow based on the dependency evaluation, for instance wait or
        start the job.
    """
    if not job_dependency[1]:
        return DependencyActions.OK
    deps_value = []
    with get_db_file_handle(find_db_file()) as db_file:
        db = get_db(db_file)
        for dep in job_dependency[1]:
            internal_deps_ok = True
            for dep_type, job_idx in dep.items():
                if dep_type != "afterok":
                    raise NotImplementedError("Dependency type {} is not implemented.".format(dep_type))
                dep_state = db[job_idx if isinstance(job_idx, list) else [job_idx]]["STATE"]
                if any(dep_state == JobState.FAILED) or any(dep_state == JobState.CANCELLED):
                    return DependencyActions.NEVER
                internal_deps_ok &= all(dep_state == JobState.COMPLETED)
            deps_value.append(internal_deps_ok)
    combined_dep = all(deps_value) if job_dependency[0] else any(deps_value)
    return DependencyActions.OK if combined_dep else DependencyActions.WAIT
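

# Illustrative pipeline (comment only): the output of parse_dependency is fed directly
# to check_dependency, which reads the job DB on disk (so the DB file must already exist):
#
#   deps = parse_dependency("afterok:20:21")
#   action = check_dependency(deps)  # DependencyActions.OK / WAIT / NEVER
#
# An empty dependency list always yields DependencyActions.OK.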


def launch_job(
    cmd: str, stdout: str, stderr: str, job_idx: int, job_dependency: Tuple[bool, List[Dict[str, List[int]]]]
):
    """Run `cmd` as a detached subprocess.

    The return code of `cmd` is retrieved by the detached subprocess and updated in the DB.

    Parameters
    ----------
    cmd : str
        Command to run, similar to the content of the --wrap= argument of sbatch
    stdout : str
        Path to a file where the output of `cmd` will be piped
    stderr : str
        Path to a file where the errors of `cmd` will be piped
    job_idx : int
        Index of the job in the DB
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies, as returned by `mockslurm.mock_sbatch.parse_dependency`
    """
    logging.debug(
        "launch_job with args cmd: {}, stdout: {}, stderr: {}, job_idx: {}, job_dependency: {}".format(
            cmd, stdout, stderr, job_idx, job_dependency
        )
    )
    # Wait for dependencies to be ready
    dependency_check = check_dependency(job_dependency)
    if dependency_check == DependencyActions.WAIT:
        logging.debug(
            "Job {} dependency {} 1st check is WAIT, updating db Reason DEPENDENCY".format(job_idx, job_dependency)
        )
        with get_db_file_handle(find_db_file()) as db_file:
            db = get_db(db_file)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.Dependency)
        while dependency_check == DependencyActions.WAIT:
            time.sleep(0.25)
            logging.debug("Job {} dependency {} check remains WAIT, sleeping".format(job_idx, job_dependency))
            dependency_check = check_dependency(job_dependency)

    if dependency_check != DependencyActions.OK:
        # Dependencies can never be satisfied: do not start the job and mark its state as FAILED
        logging.debug(
            "Job {} dependency check is {}, updating DB with REASON DependencyNeverSatisfied and STATE FAILED".format(
                job_idx, dependency_check
            )
        )
        with get_db_file_handle(find_db_file()) as db_file:
            db = get_db(db_file)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.DependencyNeverSatisfied)
            update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
    else:
        logging.debug("Job {} dependency check is OK, checking job STATE".format(job_idx))
        # lock the DB here, so STATE can not change before we start the process
        with get_db_file_handle(find_db_file()) as db_file:
            db = get_db(db_file)
            if db[job_idx]["STATE"] == JobState.PENDING:
                logging.debug("Job {} STATE is still pending, starting job".format(job_idx))
                try:
                    stdout_f = open(stdout, "a")
                    stderr_f = sp.STDOUT if stderr == stdout else open(stderr, "a")
                    # Can not use shlex here to split the command, otherwise bash commands would be split at
                    # each space, so we pass a single string with shell=True
                    p = sp.Popen(cmd, stdout=stdout_f, stderr=stderr_f, start_new_session=True, shell=True)
                    logging.debug(
                        "Job {} started with PID {}, cmd {}, stdout {}, stderr {}".format(
                            job_idx, p.pid, cmd, stdout_f, stderr_f
                        )
                    )
                    update_db_value(db_file, job_idx, key="PID", value=p.pid)
                    update_db_value(db_file, job_idx, key="STATE", value=JobState.RUNNING)
                    update_db_value(db_file, job_idx, key="REASON", value=JobReason.NOREASON)
                    logging.debug(
                        "Job {} updated db with PID {}, STATE RUNNING, REASON NOREASON".format(job_idx, p.pid)
                    )
                except Exception:
                    logging.debug(
                        "Job {} failed to start. Updating DB with STATE FAILED and REASON JobLaunchFailure".format(
                            job_idx
                        )
                    )
                    update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
                    update_db_value(db_file, job_idx, key="REASON", value=JobReason.JobLaunchFailure)
                    logging.debug("Job {} starting error".format(job_idx), exc_info=True)
                    raise
            else:
                logging.debug(
                    "Job {} STATE is {}. Do not start job and exit".format(job_idx, db[job_idx]["STATE"])
                )
                # Process is not pending anymore, it must have been killed already, so we won't start it
                return

        logging.debug("Job {} is running, waiting for its completion".format(job_idx))
        # wait for the process to be done
        exit_code = p.wait()
        logging.debug("Job {} completed with exit code {}".format(job_idx, exit_code))
        # close the stdout and stderr files of the job
        stdout_f.close()
        if stderr != stdout:
            stderr_f.close()
        logging.debug("Job {} closed stdout {} and stderr {}".format(job_idx, stdout, stderr))
        # Update job state in db
        with get_db_file_handle(find_db_file()) as db_file:
            job_state = JobState.COMPLETED if exit_code == 0 else JobState.FAILED
            job_reason = JobReason.NOREASON if exit_code == 0 else JobReason.NonZeroExitCode
            update_db_value(db_file, job_idx, key="EXIT_CODE", value=exit_code)
            update_db_value(db_file, job_idx, key="STATE", value=job_state)
            update_db_value(db_file, job_idx, key="REASON", value=job_reason)
            logging.debug(
                "Job {} Updated db with EXIT_CODE {}, STATE {}, REASON {}".format(
                    job_idx, exit_code, JobState(job_state).name, JobReason(job_reason).name
                )
            )
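

# Illustrative call (comment only, values are placeholders): main() invokes launch_job
# from the detached grandchild process, e.g.
#
#   launch_job(
#       cmd="echo hello",
#       stdout="out-42.log",
#       stderr="out-42.log",
#       job_idx=42,
#       job_dependency=parse_dependency(""),
#   )
#
# When stdout and stderr point to the same file, stderr is redirected into stdout via
# subprocess.STDOUT.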


def main():
    parser = argparse.ArgumentParser(
        description="Slurm sbatch mock.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--account", "-A", type=str, dest="ACCOUNT", help="user account", required=False)
    parser.add_argument(
        "--dependency",
        "-d",
        type=str,
        default="",
        dest="dependency",
        help="Defer the start of this job until the specified dependencies have been satisfied",
    )
    parser.add_argument(
        "--error", "-e", type=str, dest="stderr", help="error file of slurm job", default="slurm-%j.out"
    )
    parser.add_argument("--job-name", "-J", type=str, dest="NAME", help="job name", default="wrap")
    parser.add_argument(
        "--nodelist",
        "-w",
        type=str,
        dest="NODELIST",
        help="Request a specific list of hosts",
        nargs="*",
    )
    parser.add_argument(
        "--output", "-o", type=str, dest="stdout", help="output file of slurm job", default="slurm-%j.out"
    )
    parser.add_argument("--partition", "-p", type=str, dest="PARTITION", help="job partition", required=False)
    parser.add_argument(
        "--parsable",
        action="store_true",
        dest="parsable",
        help="Outputs only the job id number. Errors will still be displayed",
    )
    parser.add_argument("--reservation", type=str, dest="RESERVATION", help="job reservation", required=False)
    parser.add_argument(
        "--wrap",
        type=str,
        dest="CMD",
        help="Command to be executed",
        required=True,
    )
    parser.add_argument(
        "--mock_sbatch_debug",
        action="store_true",
        dest="debug_mode",
        help="If provided, logs all actions of the sbatch mock in a mock_sbatch.log file",
        default=False,
    )
    args, _ = parser.parse_known_args()
    defined_args = {arg: value for arg, value in vars(args).items() if value is not None}
    defined_args.pop("stdout")
    defined_args.pop("stderr")
    defined_args.pop("parsable")
    defined_args.pop("dependency")
    defined_args.pop("debug_mode")
    # TODO: raise error if arguments to sbatch aren't valid (reservation, partition, nodelist, etc.)

    if args.debug_mode:
        logging.basicConfig(
            filename="mock_sbatch.log",
            level=logging.DEBUG,
            format="%(asctime)s %(levelname)s mock_sbatch %(pathname)s:%(lineno)s:%(funcName)s %(message)s",
        )
    logging.debug(
        "\n".join(
            ["mock_sbatch called with args:"]
            + ["--{}: {}".format(arg, arg_value) for arg, arg_value in vars(args).items()]
        )
    )

    with get_db_file_handle(find_db_file()) as db_file:
        job_idx = append_job(db_file, **defined_args)
        logging.debug("Appended job {} in DB with args {}".format(job_idx, defined_args))

    # if parsable is set: print the job ID
    if args.parsable:
        print(job_idx, end="", flush=True)
        logging.debug("Printing job ID {} to stdout".format(job_idx))

    # In order to create a process that runs the sbatch command and to exit sbatch without killing the child
    # process or making it a zombie, we need to use the double fork technique used to create daemons:
    # see https://stackoverflow.com/questions/473620/how-do-you-create-a-daemon-in-python
    # https://stackoverflow.com/questions/881388/what-is-the-reason-for-performing-a-doube-fork-when-creating-a-daemon
    # Double forking is used to create a new process B (attached to the process group of A, its parent), which then
    # creates a new process C detached from A. The new process can then run and be cleaned up by process 1
    # instead of A.
    # There are a number of other things to consider to completely detach a process, such as closing all file
    # handles; here we only do the minimum for the mock to work.
    # See https://pagure.io/python-daemon/blob/main/f/src/daemon/daemon.py for a reference implementation of a
    # daemon in python (an old one, there is an associated PEP that was never merged due to lack of interest)
    logging.debug("Job {} Parent process will fork child".format(job_idx))
    # fork a first time
    pid_1st_fork = os.fork()
    if pid_1st_fork == 0:
        # we are in the child process
        logging.debug(
            "Job {} Child process, creating new process session and becoming group leader".format(job_idx)
        )
        # create a new process session for the child
        # the 2nd fork process will be in this session, detached from the initial process
        os.setsid()
        logging.debug("Job {} Child process will fork grandchild".format(job_idx))
        # second fork
        pid_2nd_fork = os.fork()
        if pid_2nd_fork == 0:
            # we are in the grandchild: run the sbatch job, then update the db
            logging.debug("Job {} Grandchild process redirecting stdin, stdout and stderr".format(job_idx))
            # redirect stdin, stdout and stderr to /dev/null (so we separate from the child's file handles)
            # (the child's file handles are the same as the parent's handles...)
            os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
            os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
            os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
            logging.debug("Job {} Grandchild process calls launch_job".format(job_idx))
            launch_job(
                args.CMD,
                args.stdout.replace("%j", str(job_idx).strip()),
                args.stderr.replace("%j", str(job_idx).strip()),
                job_idx,
                parse_dependency(args.dependency),
            )
            # exit without clean up.
            # it is REQUIRED to have the grandchild exit in this way if it is tested with pytest, otherwise pytest
            # will run the program several times, see https://github.com/pytest-dev/pytest/issues/12028
            # Side effect: the stdio buffers will not be flushed (a detached process can't output to stdio anyway)
            os._exit(0)
        else:
            # exit the child without cleaning up file handlers or flushing buffers, because those are shared
            # with the parent
            logging.debug("Job {} Child process exiting without cleaning up.".format(job_idx))
            os._exit(0)
    else:
        os.wait()  # this waits for the child (not the grandchild!) to finish, to clean it up
        logging.debug("Job {} Parent process exiting after waiting on child process.".format(job_idx))


if __name__ == "__main__":
    main()