Coverage for mockslurm/mock_sbatch.py: 53%

148 statements  

coverage.py v7.6.10, created at 2025-01-17 10:19 +0000

#!/usr/bin/env python
"""Implement a mock of the sbatch command of slurm.

The implementation mimics the sbatch API, while using subprocess to
start the processes on the host machine without slurm daemons.

Double fork is used to detach the subprocess, allowing the sbatch
mock to return while the created process is still running.
"""

import argparse
import logging
import os
import subprocess as sp
import sys
import time
from enum import Enum
from typing import Dict, List, Tuple

from mockslurm.process_db import (
    JobReason,
    JobState,
    append_job,
    find_db_file,
    get_db,
    get_db_file_handle,
    update_db_value,
)

class DependencyActions(Enum):
    OK = 1
    WAIT = 2
    NEVER = 3


def parse_dependency(job_dependency: str) -> Tuple[bool, List[Dict[str, List[int]]]]:
    """Parse a job's dependency specification string in sbatch format.

    Parameters
    ----------
    job_dependency : str
        sbatch job dependency string. Example: afterok:20:21,afterany:23

    Returns
    -------
    Tuple[bool, List[Dict[str, List[int]]]]
        The tuple's 1st element is whether all dependencies must be satisfied for the job to start
        (True, "," separator in the slurm string), or whether a single satisfied dependency is enough
        (False, "?" separator). When a single dependency is present: True.
        The list values are dictionaries with a single key being the dependency type ("afterok",
        "afterany", etc.) and the value being the job IDs for that dependency.

    Example
    -------
    >>> parse_dependency("afterok:20:21?afterany:23")
    (False, [{"afterok": [20, 21]}, {"afterany": [23]}])
    """
    if "," in job_dependency and "?" in job_dependency:
        raise ValueError('Multiple separators "?", "," in job dependency not supported.')
    separator = "?" if "?" in job_dependency else ","  # default is "," if only 1 dep
    dependencies = []
    while job_dependency:
        if job_dependency[0] == separator:
            job_dependency = job_dependency[1:]

        dep_str = job_dependency.split(separator, maxsplit=1)[0]
        job_dependency = job_dependency[len(dep_str) :]
        split = dep_str.split(":")
        dependencies.append({split[0]: [int(v) for v in split[1:]]})

    return (separator == ",", dependencies)

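# Illustrative parsing sketch (added note, not in the original module): a comma-separated
# dependency string combines its parts with "all must be satisfied" semantics:
# >>> parse_dependency("afterok:20,afterany:23")
# (True, [{"afterok": [20]}, {"afterany": [23]}])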


def check_dependency(job_dependency: Tuple[bool, List[Dict[str, List[int]]]]) -> DependencyActions:
    """Check if a job's dependencies are satisfied.

    Parameters
    ----------
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies. Each dictionary in the list corresponds to a dependency specification.
        See `mockslurm.mock_sbatch.parse_dependency`

    Warning
    -------
    This function only supports the "afterok" dependency specification, anything else causes
    a `NotImplementedError` to be raised.

    Raises
    ------
    NotImplementedError
        If the dependency type is not "afterok"

    Returns
    -------
    DependencyActions
        The action to follow based on the dependency evaluation, for instance to wait, to start the job, etc.
    """
    if not job_dependency[1]:
        return DependencyActions.OK

    deps_value = []
    with get_db_file_handle(find_db_file()) as db_file:
        db = get_db(db_file)
        for dep in job_dependency[1]:
            internal_deps_ok = True
            for dep_type, job_idx in dep.items():
                if dep_type != "afterok":
                    raise NotImplementedError("Dependency type {} is not implemented.".format(dep_type))
                dep_state = db[job_idx if isinstance(job_idx, list) else [job_idx]]["STATE"]
                if any(dep_state == JobState.FAILED) or any(dep_state == JobState.CANCELLED):
                    return DependencyActions.NEVER
                internal_deps_ok &= all(dep_state == JobState.COMPLETED)

            deps_value.append(internal_deps_ok)

    combined_dep = all(deps_value) if job_dependency[0] else any(deps_value)
    return DependencyActions.OK if combined_dep else DependencyActions.WAIT

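# Illustrative note (added summary, not in the original module): for a single "afterok"
# dependency, check_dependency above maps the DB job state to an action roughly as:
#   COMPLETED                       -> DependencyActions.OK
#   FAILED or CANCELLED             -> DependencyActions.NEVER
#   anything else (PENDING, RUNNING) -> DependencyActions.WAIT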


def launch_job(
    cmd: str, stdout: str, stderr: str, job_idx: int, job_dependency: Tuple[bool, List[Dict[str, List[int]]]]
):
    """Run `cmd` as a detached subprocess.

    The return code of `cmd` is retrieved by the detached subprocess and updated in the DB.

    Parameters
    ----------
    cmd : str
        Command to run, similar to the content of the --wrap= argument of sbatch
    stdout : str
        Path to a file where the output of `cmd` will be piped
    stderr : str
        Path to a file where the errors of `cmd` will be piped
    job_idx : int
        Index of the job in the DB
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies. Each dictionary in the list corresponds to a dependency specification.
        See `mockslurm.mock_sbatch.parse_dependency`
    """
    logging.debug(
        "launch_job with args cmd: {}, stdout: {}, stderr: {}, job_idx: {}, job_dependency: {}".format(
            cmd, stdout, stderr, job_idx, job_dependency
        )
    )
    # Wait for dependencies to be ready
    dependency_check = check_dependency(job_dependency)
    if dependency_check == DependencyActions.WAIT:
        logging.debug(
            "Job {} dependency {} 1st check is WAIT, updating db REASON to DEPENDENCY".format(job_idx, job_dependency)
        )
        with get_db_file_handle(find_db_file()) as db_file:
            db = get_db(db_file)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.Dependency)
        while dependency_check == DependencyActions.WAIT:
            time.sleep(0.25)
            logging.debug("Job {} dependency {} check remains WAIT, sleeping".format(job_idx, job_dependency))
            dependency_check = check_dependency(job_dependency)

    # If not OK: do not start the job and mark its state as FAILED
    if dependency_check != DependencyActions.OK:
        logging.debug(
            "Job {} dependency check is {}, updating DB with REASON DependencyNeverSatisfied and STATE FAILED".format(
                job_idx, dependency_check
            )
        )
        with get_db_file_handle(find_db_file()) as db_file:
            db = get_db(db_file)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.DependencyNeverSatisfied)
            update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
    else:
        logging.debug("Job {} dependency check is OK, checking job STATE".format(job_idx))
        # Lock the DB here, so STATE can not change before we start the process
        with get_db_file_handle(find_db_file()) as db_file:
            db = get_db(db_file)
            if db[job_idx]["STATE"] == JobState.PENDING:
                logging.debug("Job {} STATE is still PENDING, starting job".format(job_idx))
                try:
                    stdout_f = open(stdout, "a")
                    stderr_f = sp.STDOUT if stderr == stdout else open(stderr, "a")
                    # Can not use shlex here to split the command, otherwise we would split bash
                    # commands at each space, so we pass a single string with shell=True
                    p = sp.Popen(cmd, stdout=stdout_f, stderr=stderr_f, start_new_session=True, shell=True)
                    logging.debug(
                        "Job {} started with PID {}, cmd {}, stdout {}, stderr {}".format(
                            job_idx, p.pid, cmd, stdout_f, stderr_f
                        )
                    )
                    update_db_value(db_file, job_idx, key="PID", value=p.pid)
                    update_db_value(db_file, job_idx, key="STATE", value=JobState.RUNNING)
                    update_db_value(db_file, job_idx, key="REASON", value=JobReason.NOREASON)
                    logging.debug(
                        "Job {} updated db with PID {}, STATE RUNNING, REASON NOREASON".format(job_idx, p.pid)
                    )
                except:
                    logging.debug(
                        "Job {} failed to start. Updating DB with STATE FAILED and REASON JobLaunchFailure".format(
                            job_idx
                        )
                    )
                    update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
                    update_db_value(db_file, job_idx, key="REASON", value=JobReason.JobLaunchFailure)
                    logging.debug("Job {} starting error".format(job_idx), exc_info=True)
                    raise
            else:
                logging.debug(
                    "Job {} STATE is {}. Do not start the job and exit".format(job_idx, db[job_idx]["STATE"])
                )
                # Process is not pending anymore, it must have been killed already, so we won't start it
                return

        logging.debug("Job {} is running, waiting for its completion".format(job_idx))
        # Wait for the process to be done
        exit_code = p.wait()
        logging.debug("Job {} completed with exit code {}".format(job_idx, exit_code))
        # Close the stdout and stderr files of the job
        stdout_f.close()
        if stderr != stdout:
            stderr_f.close()
        logging.debug("Job {} closed stdout {} and stderr {}".format(job_idx, stdout, stderr))
        # Update the job state in the db
        with get_db_file_handle(find_db_file()) as db_file:
            job_state = JobState.COMPLETED if exit_code == 0 else JobState.FAILED
            job_reason = JobReason.NOREASON if exit_code == 0 else JobReason.NonZeroExitCode
            update_db_value(db_file, job_idx, key="EXIT_CODE", value=exit_code)
            update_db_value(db_file, job_idx, key="STATE", value=job_state)
            update_db_value(db_file, job_idx, key="REASON", value=job_reason)
            logging.debug(
                "Job {} updated db with EXIT_CODE {}, STATE {}, REASON {}".format(
                    job_idx, exit_code, JobState(job_state).name, JobReason(job_reason).name
                )
            )

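# Illustrative summary (added note, not in the original module): a job row in the DB moves
# through roughly these states via launch_job:
#   PENDING -> RUNNING -> COMPLETED   (exit code 0, REASON NOREASON)
#   PENDING -> RUNNING -> FAILED      (non-zero exit code, REASON NonZeroExitCode)
#   PENDING -> FAILED                 (dependency never satisfied, or job launch failure)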


def main():
    parser = argparse.ArgumentParser(
        description="Slurm sbatch mock.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--account", "-A", type=str, dest="ACCOUNT", help="user account", required=False)
    parser.add_argument(
        "--dependency",
        "-d",
        type=str,
        default="",
        dest="dependency",
        help="Defer the start of this job until the specified dependencies have been satisfied",
    )
    parser.add_argument(
        "--error", "-e", type=str, dest="stderr", help="error file of slurm job", default="slurm-%j.out"
    )
    parser.add_argument("--job-name", "-J", type=str, dest="NAME", help="job name", default="wrap")
    parser.add_argument(
        "--nodelist",
        "-w",
        type=str,
        dest="NODELIST",
        help="Request a specific list of hosts",
        nargs="*",
    )
    parser.add_argument(
        "--output", "-o", type=str, dest="stdout", help="output file of slurm job", default="slurm-%j.out"
    )
    parser.add_argument("--partition", "-p", type=str, dest="PARTITION", help="job partition", required=False)
    parser.add_argument(
        "--parsable",
        action="store_true",
        dest="parsable",
        help="Outputs only the job id number. Errors will still be displayed",
    )
    parser.add_argument("--reservation", type=str, dest="RESERVATION", help="job reservation", required=False)
    parser.add_argument(
        "--wrap",
        type=str,
        dest="CMD",
        help="Command to be executed",
        required=True,
    )
    parser.add_argument(
        "--mock_sbatch_debug",
        action="store_true",
        dest="debug_mode",
        help="If provided, logs all actions of the sbatch mock in a mock_sbatch.log file",
        default=False,
    )

    args, _ = parser.parse_known_args()
    defined_args = {arg: value for arg, value in vars(args).items() if value is not None}
    defined_args.pop("stdout")
    defined_args.pop("stderr")
    defined_args.pop("parsable")
    defined_args.pop("dependency")
    defined_args.pop("debug_mode")
    # TODO: raise error if arguments to sbatch aren't valid (reservation, partition, nodelist? etc.)

    if args.debug_mode:  # coverage: line 290 didn't jump to line 297 because the condition was always true
        logging.basicConfig(
            filename="mock_sbatch.log",
            level=logging.DEBUG,
            format="%(asctime)s %(levelname)s mock_sbatch %(pathname)s:%(lineno)s:%(funcName)s %(message)s",
        )

    logging.debug(
        "\n".join(
            ["mock_sbatch called with args:"]
            + ["--{}: {}".format(arg, arg_value) for arg, arg_value in vars(args).items()]
        )
    )

    with get_db_file_handle(find_db_file()) as db_file:
        job_idx = append_job(db_file, **defined_args)
        logging.debug("Appended job {} in DB with args {}".format(job_idx, defined_args))

    # If parsable is set: print the job ID
    if args.parsable:  # coverage: line 309 didn't jump to line 310 because the condition was never true
        print(job_idx, end="", flush=True)
        logging.debug("Printing job ID {} to stdout".format(job_idx))

    # In order to create a process to run the sbatch command and exit sbatch without killing the child process or
    # making it a zombie, we need to use the double fork technique used to create daemons:
    # see https://stackoverflow.com/questions/473620/how-do-you-create-a-daemon-in-python
    # https://stackoverflow.com/questions/881388/what-is-the-reason-for-performing-a-doube-fork-when-creating-a-daemon
    # Double forking is used to create a new process B (attached to the process group of A, its parent), then
    # create a new process C and detach it from A. The new process can then run and be cleaned up by process 1
    # instead of A.
    # There are a number of other things to consider to completely detach a process, such as closing all file handles;
    # here we only do the minimum for the mock to work.
    # See https://pagure.io/python-daemon/blob/main/f/src/daemon/daemon.py for a reference implementation of a daemon
    # in python (an old one; there is an associated PEP that was never accepted due to lack of interest)

    logging.debug("Job {} Parent process will fork child".format(job_idx))
    # Fork a first time
    pid_1st_fork = os.fork()
    if pid_1st_fork == 0:  # we are in the child process (coverage: 328 ↛ 329, condition never true in the test run)
        logging.debug("Job {} Child process, creating new process session and becoming group leader".format(job_idx))
        # Create a new process session for the child.
        # The 2nd fork's process will be in this session, detached from the initial process
        os.setsid()

        logging.debug("Job {} Child process will fork grandchild".format(job_idx))
        # Second fork
        pid_2nd_fork = os.fork()
        if pid_2nd_fork == 0:  # we are in the grandchild: run the sbatch job, then update db
            logging.debug("Job {} Grandchild process redirecting stdin, stdout and stderr".format(job_idx))
            # Redirect stdin, stdout and stderr to /dev/null (so we separate from the child's file handles)
            # (the child's file handles are the same as the parent's handles...)
            os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
            os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
            os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
            logging.debug("Job {} Grandchild process calls launch_job".format(job_idx))
            launch_job(
                args.CMD,
                args.stdout.replace("%j", str(job_idx).strip()),
                args.stderr.replace("%j", str(job_idx).strip()),
                job_idx,
                parse_dependency(args.dependency),
            )
            # Exit without clean-up.
            # It is REQUIRED to have the grandchild exit in this way if it is tested with pytest, otherwise pytest
            # will run the program several times, see https://github.com/pytest-dev/pytest/issues/12028
            # Side effect: the stdio buffers will not be flushed (the detached process can't output to stdio anyway)
            os._exit(0)
        else:
            # Exit the child without cleaning up file handles or flushing buffers: those are shared with the parent
            logging.debug("Job {} Child process exiting without cleaning up.".format(job_idx))
            os._exit(0)
    else:
        os.wait()  # this waits for the child (not the grandchild!) to finish, so it can be cleaned up
        logging.debug("Job {} Parent process exiting after waiting on child process.".format(job_idx))


if __name__ == "__main__":  # coverage: line 366 didn't jump to line 367 because the condition was never true
    main()
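
The listing above ends here. For orientation, below is a minimal usage sketch; it is an assumption, not part of the covered module or its tests. It supposes the mock is exposed on PATH as an `sbatch` executable by the mockslurm test setup, and that a job DB file already exists so `find_db_file()` succeeds.

import subprocess as sp

# Submit a first job; --parsable makes the mock print only the job ID.
first_id = sp.run(
    ["sbatch", "--parsable", "--wrap", "echo hello"],
    capture_output=True, text=True, check=True,
).stdout.strip()

# Submit a second job that starts only once the first one has COMPLETED successfully.
sp.run(
    ["sbatch", "--parsable", "--dependency", "afterok:" + first_id, "--wrap", "echo after"],
    capture_output=True, text=True, check=True,
)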