Coverage for mockslurm/mock_sbatch.py: 64%
105 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-21 00:38 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-21 00:38 +0000
1#!/usr/bin/env python
2"""Implement a mock of the sbatch command of slurm.
4The implementation mimics the sbatch API, while using subprocess to
5start the processes on the host machine without slurm daemons.
7Double fork is used to detach the subprocess, allowing the sbatch
8mock to return while the created process is still running.
9"""
11import argparse
12import os
13import subprocess as sp
14import time
15from enum import Enum
16from multiprocessing import Process
17from typing import Dict, List, Tuple
19from mockslurm.process_db import (
20 JobReason,
21 JobState,
22 append_job,
23 find_db_file,
24 get_db,
25 get_db_file_handle,
26 update_db_value,
27)
class DependencyActions(Enum):
    """Outcome of evaluating a job's dependencies, as returned by `check_dependency`."""

    # Dependencies are satisfied (or there are none): the job can be started.
    OK = 1
    # Dependencies are not satisfied yet: `launch_job` polls again later.
    WAIT = 2
    # Dependencies can no longer be satisfied: `launch_job` marks the job as failed.
    NEVER = 3
def parse_dependency(job_dependency: str) -> Tuple[bool, List[Dict[str, List[int]]]]:
    """Parse a job's dependency specification string in sbatch format.

    Parameters
    ----------
    job_dependency : str
        sbatch job's dependency string. Examples: afterok:20:21,afterany:23

    Returns
    -------
    Tuple[bool, List[Dict[str, List[int]]]]
        Tuple 1st element is whether all dependencies must be satisfied for a job to start (True, "," in slurm str),
        or if a single dependency is enough (False, "?" in slurm str). When a single dependency (or none) is
        present: True.
        The list values are dictionaries with a single key being the dependency type ("afterok", "afterany", etc.)
        and the value being the list of job IDs for this dependency.

    Raises
    ------
    ValueError
        If both "," and "?" separators are used in the same string, or if a job ID is not an integer.

    Example
    -------
    >>> parse_dependency("afterok:20:21?afterany:23")
    (False, [{'afterok': [20, 21]}, {'afterany': [23]}])
    >>> parse_dependency("afterok:20:21,afterany:23")
    (True, [{'afterok': [20, 21]}, {'afterany': [23]}])
    >>> parse_dependency("")
    (True, [])
    """
    if "," in job_dependency and "?" in job_dependency:
        raise ValueError('Multiple separator "?", ","in job dependency not supported.')
    separator = "?" if "?" in job_dependency else ","  # default is "," if only 1 dep

    dependencies = []
    for dep_str in job_dependency.split(separator):
        # Leading, trailing or doubled separators produce empty segments: skip them
        # instead of recording a bogus dependency with an empty type and no job IDs.
        if not dep_str:
            continue
        dep_type, *job_ids = dep_str.split(":")
        dependencies.append({dep_type: [int(job_id) for job_id in job_ids]})

    return (separator == ",", dependencies)
def check_dependency(job_dependency: Tuple[bool, List[Dict[str, List[int]]]]) -> "DependencyActions":
    """Check if a job's dependencies are satisfied.

    Parameters
    ----------
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies, as returned by `mockslurm.mock_sbatch.parse_dependency`: the first element tells
        whether all dependencies must be satisfied (True) or if a single one is enough (False), the second
        is the list of dependency specifications.

    Warning
    -------
    This function only supports the "afterok" dependency specification, anything else causes
    a `NotImplementedError` to be raised.

    Raises
    ------
    NotImplementedError
        If the dependency type is not "afterok"

    Returns
    -------
    DependencyActions
        The action to follow based on the dependency evaluation:
        - OK: there is no dependency, or the dependencies are satisfied
        - WAIT: the dependencies are not satisfied yet, but still can be
        - NEVER: the dependencies can not be satisfied anymore. When all dependencies are required, a single
          failed job is enough; when a single dependency is enough, every alternative must contain a failed job.
    """
    all_required, dependencies = job_dependency
    if not dependencies:
        return DependencyActions.OK

    satisfied = []  # per dependency: all its jobs are COMPLETED
    impossible = []  # per dependency: at least one of its jobs is FAILED
    with get_db_file_handle(find_db_file()) as db_file:
        db = get_db(db_file)
        for dep in dependencies:
            dep_satisfied = True
            dep_impossible = False
            for dep_type, job_idx in dep.items():
                if dep_type != "afterok":
                    raise NotImplementedError("Dependency type {} is not implemented.".format(dep_type))
                # NOTE(review): relies on the DB supporting indexing by a list of job indices and on
                # element-wise comparison of the "STATE" field - confirm against mockslurm.process_db
                dep_state = db[job_idx if isinstance(job_idx, list) else [job_idx]]["STATE"]
                dep_impossible |= any(dep_state == JobState.FAILED)
                dep_satisfied &= all(dep_state == JobState.COMPLETED)
            satisfied.append(dep_satisfied)
            impossible.append(dep_impossible)

    if all_required:
        # "," separator: one impossible dependency is enough to doom the job
        if any(impossible):
            return DependencyActions.NEVER
        return DependencyActions.OK if all(satisfied) else DependencyActions.WAIT

    # "?" separator: one satisfied alternative is enough, and the job is only doomed
    # once every alternative has become impossible
    if any(satisfied):
        return DependencyActions.OK
    return DependencyActions.NEVER if all(impossible) else DependencyActions.WAIT
def launch_job(
    cmd: str, stdout: str, stderr: str, job_idx: int, job_dependency: Tuple[bool, List[Dict[str, List[int]]]]
):
    """Runs `cmd` as a detached subprocess.

    The calling process forks: the parent returns immediately, while the forked child waits for the job's
    dependencies, starts `cmd`, waits for it to finish and records its PID, state, reason and exit code
    in the DB.

    Parameters
    ----------
    cmd : str
        Command to run, similar to the content of the --wrap= argument of sbatch
    stdout : str
        Path to a file where the output of `cmd` will be appended (shell redirection)
    stderr : str
        Path to a file where the errors of `cmd` will be appended (shell redirection)
    job_idx : int
        Index of the job in the DB
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies, as returned by `mockslurm.mock_sbatch.parse_dependency`
    """
    # In order for this function to keep running while the initial process exits
    # we need to "double-fork", to get out of the parent's process group
    # to do this: - call fork.
    #             - If we stayed in the parent process (fork value == child PID != 0) -> return immediately
    #             - If we are the child (fork value == 0): do stuff
    # https://stackoverflow.com/questions/5631624/how-to-get-exit-code-when-using-python-subprocess-communicate-method
    # WARNING: we need to do this before starting the process, otherwise
    # return code of executed process is invalid (always 0)
    if os.fork() != 0:
        return

    # Wait for dependencies to be ready
    dependency_check = check_dependency(job_dependency)
    if dependency_check == DependencyActions.WAIT:
        with get_db_file_handle(find_db_file()) as db_file:
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.Dependency)
        # Poll without holding the DB handle, check_dependency opens its own
        while dependency_check == DependencyActions.WAIT:
            time.sleep(0.25)
            dependency_check = check_dependency(job_dependency)

    # If not ok: do not start job and mark its state as failed
    if dependency_check != DependencyActions.OK:
        with get_db_file_handle(find_db_file()) as db_file:
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.DependencyNeverSatisfied)
            update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
        return

    # can't use stdout, stderr of subprocess since we detach process and exit parent
    # (so file descriptors would be closed) so use shell redirection
    # Can not use shlex here to split command, otherwise output redirection is treated as
    # another argument (surrounded by "") and treated as another argument by the command
    # so we use a single string with shell=True
    cmd += " 1>>" + stdout + " 2>>" + stderr

    # lock the DB here, so STATE can not change before we start the process
    with get_db_file_handle(find_db_file()) as db_file:
        db = get_db(db_file)
        if db[job_idx]["STATE"] != JobState.PENDING:
            # Process is not pending anymore, it must have been killed already, so we won't start it
            return
        try:
            p = sp.Popen(cmd, start_new_session=True, shell=True)
            update_db_value(db_file, job_idx, key="PID", value=p.pid)
            update_db_value(db_file, job_idx, key="STATE", value=JobState.RUNNING)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.NOREASON)
        except BaseException:
            # Record the launch failure for any error (including interrupts), then re-raise it
            update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.JobLaunchFailure)
            raise

    # wait for process to be done
    exit_code = p.wait()
    succeeded = exit_code == 0
    with get_db_file_handle(find_db_file()) as db_file:
        update_db_value(db_file, job_idx, key="EXIT_CODE", value=exit_code)
        update_db_value(
            db_file, job_idx, key="STATE", value=JobState.COMPLETED if succeeded else JobState.FAILED
        )
        # The reason must be derived from the exit code of the job as well
        update_db_value(
            db_file, job_idx, key="REASON", value=JobReason.NOREASON if succeeded else JobReason.NonZeroExitCode
        )
def main():
    """Entry point of the sbatch mock.

    Parses the sbatch-like command line arguments (unknown arguments are ignored), registers the job in the
    DB, then starts `launch_job` in a separate process which detaches itself to run the wrapped command.
    The job index is printed only when --parsable is set.

    Raises
    ------
    ValueError
        If the --dependency string is invalid (see `parse_dependency`). It is raised before the job is
        registered in the DB.
    """
    parser = argparse.ArgumentParser(
        description="Slurm sbatch mock.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--account", "-A", type=str, dest="ACCOUNT", help="user account", required=False)
    parser.add_argument(
        "--dependency",
        "-d",
        type=str,
        default="",
        dest="dependency",
        help="Defer the start of this job until the specified dependencies have been satisfied",
    )
    parser.add_argument(
        "--error", "-e", type=str, dest="stderr", help="error file of slurm job", default="slurm-%j.out"
    )
    parser.add_argument("--job-name", "-J", type=str, dest="NAME", help="job name", default="wrap")
    parser.add_argument(
        "--nodelist",
        "-w",
        type=str,
        dest="NODELIST",
        help="Request a specific list of hosts",
        nargs="*",
    )
    parser.add_argument(
        "--output", "-o", type=str, dest="stdout", help="output file of slurm job", default="slurm-%j.out"
    )
    parser.add_argument("--partition", "-p", type=str, dest="PARTITION", help="job partition", required=False)
    parser.add_argument(
        "--parsable",
        action="store_true",
        dest="parsable",
        help="Outputs only the job id number. Errors will still be displayed",
    )
    parser.add_argument("--reservation", type=str, dest="RESERVATION", help="job reservation", required=False)
    parser.add_argument(
        "--wrap",
        type=str,
        dest="CMD",
        help="Command to be executed",
        required=True,
    )

    args, _ = parser.parse_known_args()

    # Parse the dependencies before registering the job: an invalid specification must abort
    # the submission instead of crashing the detached process and leaving the job pending forever
    job_dependency = parse_dependency(args.dependency)

    # Only the upper case destinations are forwarded to the DB, these ones are used by the launcher only
    launcher_only_args = ("stdout", "stderr", "parsable", "dependency")
    defined_args = {
        arg: value
        for arg, value in vars(args).items()
        if value is not None and arg not in launcher_only_args
    }

    with get_db_file_handle(find_db_file()) as db_file:
        job_idx = append_job(db_file, **defined_args)

    # TODO: raise error if arguments to sbatch aren't valid (reservation, partition, nodelist ? etc.)

    # Pass the function and its arguments explicitly rather than a lambda: a lambda can not be pickled,
    # which is required by the "spawn" and "forkserver" multiprocessing start methods
    detach_p = Process(
        target=launch_job,
        args=(
            args.CMD,
            args.stdout.replace("%j", str(job_idx)).strip(),
            args.stderr.replace("%j", str(job_idx)).strip(),
            job_idx,
            job_dependency,
        ),
    )
    detach_p.start()
    detach_p.join()  # wait for the detach process that stayed in the parent process group

    # if parsable is set: print jobID
    if args.parsable:
        print(job_idx, end="")
# Run the sbatch mock only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()