Source code for mockslurm.mock_sbatch

#!/usr/bin/env python
"""Implement a mock of the sbatch command of slurm.

The implementation mimics the sbatch API, while using subprocess to 
start the processes on the host machine without slurm daemons.

Double fork is used to detach the subprocess, allowing the sbatch
mock to return while the created process is still running.
"""

import argparse
import os
import subprocess as sp
import time
from enum import Enum
from multiprocessing import Process
from typing import Dict, List, Tuple

from mockslurm.process_db import (
    JobReason,
    JobState,
    append_job,
    find_db_file,
    get_db,
    get_db_file_handle,
    update_db_value,
)


class DependencyActions(Enum):
    """Outcome of the evaluation of a job's dependencies (see `check_dependency`).

    The members keep explicit integer values so that their ``.value`` stays stable.
    """

    # Dependencies are satisfied (or there are none): the job can be started.
    OK = 1
    # Dependencies are not resolved yet: the caller keeps polling.
    WAIT = 2
    # Dependencies can not be satisfied anymore: the job must not be started.
    NEVER = 3
def parse_dependency(job_dependency: str) -> Tuple[bool, List[Dict[str, List[int]]]]:
    """Parse a job's dependency specification string in sbatch format

    Parameters
    ----------
    job_dependency : str
        sbatch job's dependency string. Examples: afterok:20:21,afterany:23

    Returns
    -------
    Tuple[bool, List[Dict[str, List[int]]]]
        Tuple 1st element is whether all dependencies must be satisfied for a job to start (True, "," in slurm str),
        or if a single dependency is enough (False, "?" in slurm str). When a single dependency (or none) is
        present: True.
        The list values are dictionaries with a single key being the dependency type ("afterok", "afterany", etc.)
        and the values being the job IDs for this dependency. Empty entries (for instance produced by a trailing
        or doubled separator) are ignored.

    Raises
    ------
    ValueError
        If both "," and "?" separators are used in the same string, or if a job ID is not an integer.

    Example
    -------
    >>> parse_dependency("afterok:20:21?afterany:23")
    (False, [{'afterok': [20, 21]}, {'afterany': [23]}])
    >>> parse_dependency("afterok:20:21,afterany:23")
    (True, [{'afterok': [20, 21]}, {'afterany': [23]}])
    """
    if "," in job_dependency and "?" in job_dependency:
        raise ValueError('Multiple separator "?", "," in job dependency not supported.')
    # Default is "," (all dependencies required) when there is a single dependency or none
    separator = "?" if "?" in job_dependency else ","
    dependencies = []
    for dep_str in job_dependency.split(separator):
        if not dep_str:
            # Leading / trailing / doubled separators yield empty entries: they carry no dependency
            continue
        dep_type, *job_ids = dep_str.split(":")
        dependencies.append({dep_type: [int(v) for v in job_ids]})
    return (separator == ",", dependencies)
def check_dependency(job_dependency: Tuple[bool, List[Dict[str, List[int]]]]) -> "DependencyActions":
    """Check if a job's dependencies are satisfied

    Parameters
    ----------
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies, as returned by `mockslurm.mock_sbatch.parse_dependency`: a flag telling whether all
        dependencies are required (True) or if a single one is enough (False), and the list of dependency
        specifications.

    Warning
    -------
    This function only supports the "afterok" dependency specification, anything else causes a
    `NotImplementedError` to be raised.

    Raises
    ------
    NotImplementedError
        If a dependency type is not "afterok". This is checked before the DB is opened.

    Returns
    -------
    DependencyActions
        The action to follow based on the dependency evaluation:
        - OK if the dependencies are satisfied (or if there are none),
        - NEVER if they can not be satisfied anymore: when all dependencies are required, as soon as one of them
          has a FAILED job; when a single one is enough, only when every one of them has a FAILED job,
        - WAIT otherwise.
    """
    all_required, dependencies = job_dependency
    if not dependencies:
        return DependencyActions.OK

    # Reject unsupported dependency types up-front, without opening (and locking) the DB
    for dep in dependencies:
        for dep_type in dep:
            if dep_type != "afterok":
                raise NotImplementedError("Dependency type {} is not implemented.".format(dep_type))

    # Evaluate each dependency specification on its own
    statuses = []
    with get_db_file_handle(find_db_file()) as db_file:
        db = get_db(db_file)
        for dep in dependencies:
            status = DependencyActions.OK
            for job_idx in dep.values():
                dep_state = db[job_idx if isinstance(job_idx, list) else [job_idx]]["STATE"]
                if any(dep_state == JobState.FAILED):
                    status = DependencyActions.NEVER
                    break
                if not all(dep_state == JobState.COMPLETED):
                    status = DependencyActions.WAIT
            statuses.append(status)

    if all_required:
        # "," separator: every dependency must be satisfied, so a single impossible one blocks the job forever
        if DependencyActions.NEVER in statuses:
            return DependencyActions.NEVER
        if all(status == DependencyActions.OK for status in statuses):
            return DependencyActions.OK
        return DependencyActions.WAIT

    # "?" separator: one satisfied dependency is enough, the job is blocked forever only if all are impossible
    if DependencyActions.OK in statuses:
        return DependencyActions.OK
    if all(status == DependencyActions.NEVER for status in statuses):
        return DependencyActions.NEVER
    return DependencyActions.WAIT
def launch_job(
    cmd: str, stdout: str, stderr: str, job_idx: int, job_dependency: Tuple[bool, List[Dict[str, List[int]]]]
):
    """Runs `cmd` as a detached subprocess.

    The return code of `cmd` is retrieved by the detached subprocess and updated in the DB.

    Parameters
    ----------
    cmd : str
        Command to run, similar to the content of the --wrap= argument of sbatch
    stdout : str
        Path to a file where the output of `cmd` will be appended
    stderr : str
        Path to a file where the errors of `cmd` will be appended
    job_idx : int
        Index of the job in the DB
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies, see `mockslurm.mock_sbatch.parse_dependency`

    Notes
    -----
    The calling process returns immediately after forking. The forked child:
    - waits for the dependencies,
    - starts `cmd` if the job is still PENDING,
    - waits for `cmd` to finish,
    - records EXIT_CODE, STATE and REASON in the DB.
    """
    # Fork so that the caller can return (and exit) while the forked child keeps running on its own.
    # os.fork() returns the child's PID (!= 0) in the parent -> return immediately,
    # and 0 in the child -> carry on with the job.
    # https://stackoverflow.com/questions/5631624/how-to-get-exit-code-when-using-python-subprocess-communicate-method
    # WARNING: we need to do this before starting the process, otherwise
    # return code of executed process is invalid (always 0)
    if os.fork() != 0:
        return

    # Wait for dependencies to be ready
    dependency_check = check_dependency(job_dependency)
    if dependency_check == DependencyActions.WAIT:
        with get_db_file_handle(find_db_file()) as db_file:
            get_db(db_file)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.Dependency)
    while dependency_check == DependencyActions.WAIT:
        time.sleep(0.25)
        dependency_check = check_dependency(job_dependency)

    # If not ok: do not start the job and mark it as FAILED, its dependencies can never be satisfied
    if dependency_check != DependencyActions.OK:
        with get_db_file_handle(find_db_file()) as db_file:
            get_db(db_file)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.DependencyNeverSatisfied)
            update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
        return

    # Can't use the stdout / stderr arguments of Popen: the process is detached and the parent exits
    # (so file descriptors would be closed), so use shell redirection instead.
    # shlex can not be used to split the command: the output redirections would be quoted and passed
    # as arguments to the command, so a single string is run with shell=True.
    # NOTE(review): the redirections only bind to the last command of a compound `cmd` (e.g. "a; b").
    cmd += " 1>>" + stdout + " 2>>" + stderr

    # Lock the DB here, so STATE can not change before we start the process
    with get_db_file_handle(find_db_file()) as db_file:
        db = get_db(db_file)
        if db[job_idx]["STATE"] == JobState.PENDING:
            try:
                p = sp.Popen(cmd, start_new_session=True, shell=True)
                update_db_value(db_file, job_idx, key="PID", value=p.pid)
                update_db_value(db_file, job_idx, key="STATE", value=JobState.RUNNING)
                update_db_value(db_file, job_idx, key="REASON", value=JobReason.NOREASON)
            except BaseException:
                # Record the launch failure in the DB, then propagate the error
                update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
                update_db_value(db_file, job_idx, key="REASON", value=JobReason.JobLaunchFailure)
                raise
        else:
            # Process is not pending anymore, it must have been killed already, so we won't start it
            return

    # Wait for the process to be done, then record its outcome
    exit_code = p.wait()
    succeeded = exit_code == 0
    with get_db_file_handle(find_db_file()) as db_file:
        update_db_value(db_file, job_idx, key="EXIT_CODE", value=exit_code)
        update_db_value(db_file, job_idx, key="STATE", value=JobState.COMPLETED if succeeded else JobState.FAILED)
        update_db_value(
            db_file, job_idx, key="REASON", value=JobReason.NOREASON if succeeded else JobReason.NonZeroExitCode
        )
def main():
    """Entry point of the sbatch mock.

    Parses an sbatch-like command line, registers the job in the DB, starts it detached (see `launch_job`)
    and, with --parsable, prints the job ID.
    """
    parser = argparse.ArgumentParser(
        description="Slurm sbatch mock.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--account", "-A", type=str, dest="ACCOUNT", help="user account", required=False)
    parser.add_argument(
        "--dependency",
        "-d",
        type=str,
        default="",
        dest="dependency",
        help="Defer the start of this job until the specified dependencies have been satisfied",
    )
    parser.add_argument(
        "--error",
        "-e",
        type=str,
        dest="stderr",
        help="error file of slurm job (same file as --output when not given, as in sbatch)",
        default=None,
    )
    parser.add_argument("--job-name", "-J", type=str, dest="NAME", help="job name", default="wrap")
    parser.add_argument(
        "--nodelist",
        "-w",
        type=str,
        dest="NODELIST",
        help="Request a specific list of hosts",
        nargs="*",
    )
    parser.add_argument(
        "--output", "-o", type=str, dest="stdout", help="output file of slurm job", default="slurm-%j.out"
    )
    parser.add_argument("--partition", "-p", type=str, dest="PARTITION", help="job partition", required=False)
    parser.add_argument(
        "--parsable",
        action="store_true",
        dest="parsable",
        help="Outputs only the job id number. Errors will still be displayed",
    )
    parser.add_argument("--reservation", type=str, dest="RESERVATION", help="job reservation", required=False)
    parser.add_argument(
        "--wrap",
        type=str,
        dest="CMD",
        help="Command to be executed",
        required=True,
    )
    args, _ = parser.parse_known_args()

    # Like sbatch, send stderr to the --output file when --error is not given
    if args.stderr is None:
        args.stderr = args.stdout

    # Parse the dependency before registering the job, so that an invalid specification is reported here
    # instead of failing in the detached process after the job was added to the DB
    job_dependency = parse_dependency(args.dependency)

    # Only job attributes are stored in the DB: drop the options that only drive this command
    defined_args = {arg: value for arg, value in vars(args).items() if value is not None}
    for sbatch_only_arg in ("stdout", "stderr", "parsable", "dependency"):
        defined_args.pop(sbatch_only_arg, None)

    with get_db_file_handle(find_db_file()) as db_file:
        job_idx = append_job(db_file, **defined_args)

    # TODO: raise error if arguments to sbatch aren't valid (reservation, partition, nodelist ? etc.)

    # Pass the module-level function and its arguments directly (no lambda), so that the target can be
    # pickled if multiprocessing uses the "spawn" start method
    detach_p = Process(
        target=launch_job,
        args=(
            args.CMD,
            args.stdout.replace("%j", str(job_idx)).strip(),
            args.stderr.replace("%j", str(job_idx)).strip(),
            job_idx,
            job_dependency,
        ),
    )
    detach_p.start()
    detach_p.join()  # wait for the detach process that stayed in the parent process group

    # if parsable is set: print jobID
    if args.parsable:
        print(job_idx, end="")
if __name__ == "__main__":
    # Run the sbatch mock CLI only when the module is executed as a script, not when it is imported
    main()