Coverage for mockslurm/mock_sbatch.py: 66%

105 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-04 23:38 +0000

1#!/usr/bin/env python 

2"""Implement a mock of the sbatch command of slurm. 

3 

4The implementation mimics the sbatch API, while using subprocess to  

5start the processes on the host machine without slurm daemons. 

6 

7Double fork is used to detach the subprocess, allowing the sbatch 

8mock to return while the created process is still running. 

9""" 

10 

11import argparse 

12import os 

13import subprocess as sp 

14import time 

15from enum import Enum 

16from multiprocessing import Process 

17from typing import Dict, List, Tuple 

18 

19from mockslurm.process_db import ( 

20 JobReason, 

21 JobState, 

22 append_job, 

23 find_db_file, 

24 get_db, 

25 get_db_file_handle, 

26 update_db_value, 

27) 

28 

29 

class DependencyActions(Enum):
    """Action to take for a job after evaluating its dependencies."""

    # Dependencies are satisfied (or there are none): the job can be started.
    OK = 1
    # Dependencies are not satisfied yet: keep polling before starting the job.
    WAIT = 2
    # Dependencies can never be satisfied: the job must not be started.
    NEVER = 3

34 

35 

def parse_dependency(job_dependency: str) -> Tuple[bool, List[Dict[str, List[int]]]]:
    """Parse a job's dependency specification string in sbatch format.

    Parameters
    ----------
    job_dependency : str
        sbatch job's dependency string. Examples: afterok:20:21,afterany:23

    Returns
    -------
    Tuple[bool, List[Dict[str, List[int]]]]
        Tuple 1st element is whether all dependencies must be satisfied for a job to start (True, "," in slurm str),
        or if a single dependency is enough (False, "?" in slurm str). When a single dependency is present: True.
        The list values are dictionaries with a single key being the dependency type ("afterok", "afterany", etc.)
        and the values being the job IDs for this dependency.
        An empty `job_dependency` yields ``(True, [])``.

    Raises
    ------
    ValueError
        If both "," and "?" separators are present, or if a job ID is not an integer.

    Example
    -------
    >>> parse_dependency("afterok:20:21,afterany:23")
    (True, [{'afterok': [20, 21]}, {'afterany': [23]}])
    >>> parse_dependency("afterok:20:21?afterany:23")
    (False, [{'afterok': [20, 21]}, {'afterany': [23]}])
    """
    if "," in job_dependency and "?" in job_dependency:
        raise ValueError('Multiple separator "?", ","in job dependency not supported.')
    separator = "?" if "?" in job_dependency else ","  # default is "," if only 1 dep

    dependencies = []
    for dep_str in job_dependency.split(separator):
        # Skip empty segments (empty input, leading/trailing or doubled separators) so that
        # no bogus {"": []} dependency is produced.
        if not dep_str:
            continue
        dep_type, *job_ids = dep_str.split(":")
        dependencies.append({dep_type: [int(job_id) for job_id in job_ids]})

    return (separator == ",", dependencies)

70 

71 

def check_dependency(job_dependency: Tuple[bool, List[Dict[str, List[int]]]]) -> "DependencyActions":
    """Check if a job's dependencies are satisfied.

    Parameters
    ----------
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies, as returned by `mockslurm.mock_sbatch.parse_dependency`:
        the 1st element tells whether all dependencies must be satisfied (True) or if any
        is enough (False), the 2nd element is the list of dependency specifications.

    Warning
    -------
    This function only supports the "afterok" dependency specification, anything else causes
    a `NotImplementedError` to be raised.

    Raises
    ------
    NotImplementedError
        If the dependency type is not "afterok"

    Returns
    -------
    DependencyActions
        The action to follow based on the dependency evaluation: `OK` to start the job,
        `WAIT` to check again later, `NEVER` if a job depended upon has failed.
    """
    require_all = job_dependency[0]
    dependencies = job_dependency[1]
    if not dependencies:
        # Nothing to wait for
        return DependencyActions.OK

    deps_value = []
    with get_db_file_handle(find_db_file()) as db_file:
        db = get_db(db_file)
        for dep in dependencies:
            internal_deps_ok = True
            for dep_type, job_idx in dep.items():
                if dep_type != "afterok":
                    raise NotImplementedError("Dependency type {} is not implemented.".format(dep_type))
                # Index the DB with a list of job indices to get the state of every job depended upon
                # (presumably `db` is an array-like indexable by a list of rows then by field name).
                dep_state = db[job_idx if isinstance(job_idx, list) else [job_idx]]["STATE"]
                # NOTE(review): a single failed job yields NEVER even when only one dependency
                # needs to be satisfied ("?" separator) - confirm this is the intended behavior.
                if any(dep_state == JobState.FAILED):
                    return DependencyActions.NEVER
                internal_deps_ok &= all(dep_state == JobState.COMPLETED)

            deps_value.append(internal_deps_ok)

    combined_dep = all(deps_value) if require_all else any(deps_value)
    return DependencyActions.OK if combined_dep else DependencyActions.WAIT

116 

117 

def launch_job(
    cmd: str, stdout: str, stderr: str, job_idx: int, job_dependency: Tuple[bool, List[Dict[str, List[int]]]]
):
    """Runs `cmd` as a detached subprocess.

    The return code of `cmd` is retrieved by the detached subprocess and updated in the DB.

    Parameters
    ----------
    cmd : str
        Command to run, similar to the content of the --wrap= argument of sbatch
    stdout : str
        Path to a file where the output of `cmd` will be appended
    stderr : str
        Path to a file where the errors of `cmd` will be appended
    job_idx : int
        Index of the job in the DB
    job_dependency : Tuple[bool, List[Dict[str, List[int]]]]
        Job dependencies, as returned by `mockslurm.mock_sbatch.parse_dependency`.
    """
    # In order for this function to keep running while the initial process exits
    # we need to "double-fork", to get out of the parent's process group.
    # To do this: - call fork.
    #             - If we are in the parent process (fork value == child PID != 0): return immediately
    #             - If we are the child (fork value == 0): do stuff
    # https://stackoverflow.com/questions/5631624/how-to-get-exit-code-when-using-python-subprocess-communicate-method
    # WARNING: we need to do this before starting the process, otherwise
    # return code of executed process is invalid (always 0)
    if os.fork() != 0:
        return

    # Wait for dependencies to be ready
    dependency_check = check_dependency(job_dependency)
    if dependency_check == DependencyActions.WAIT:
        with get_db_file_handle(find_db_file()) as db_file:
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.Dependency)
        # Poll outside of the `with` block: check_dependency opens the DB file itself
        while dependency_check == DependencyActions.WAIT:
            time.sleep(0.25)
            dependency_check = check_dependency(job_dependency)

    # If not ok: do not start the job and mark its state as failed
    if dependency_check != DependencyActions.OK:
        with get_db_file_handle(find_db_file()) as db_file:
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.DependencyNeverSatisfied)
            update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
        return

    # can't use stdout, stderr of subprocess since we detach process and exit parent
    # (so file descriptors would be closed) so use shell redirection
    # Can not use shlex here to split command, otherwise output redirection is treated as
    # another argument (surrounded by "") and treated as another argument by the command
    # so we use a single string with shell=True
    cmd += " 1>>" + stdout + " 2>>" + stderr
    # lock the DB here, so STATE can not change before we start the process
    with get_db_file_handle(find_db_file()) as db_file:
        db = get_db(db_file)
        if db[job_idx]["STATE"] != JobState.PENDING:
            # Process is not pending anymore, it must have been killed already, so we won't start it
            return
        try:
            p = sp.Popen(cmd, start_new_session=True, shell=True)
            update_db_value(db_file, job_idx, key="PID", value=p.pid)
            update_db_value(db_file, job_idx, key="STATE", value=JobState.RUNNING)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.NOREASON)
        except BaseException:
            # Record the launch failure in the DB whatever the cause, then propagate the error
            update_db_value(db_file, job_idx, key="STATE", value=JobState.FAILED)
            update_db_value(db_file, job_idx, key="REASON", value=JobReason.JobLaunchFailure)
            raise

    # wait for process to be done
    exit_code = p.wait()
    succeeded = exit_code == 0
    with get_db_file_handle(find_db_file()) as db_file:
        update_db_value(db_file, job_idx, key="EXIT_CODE", value=exit_code)
        update_db_value(db_file, job_idx, key="STATE", value=JobState.COMPLETED if succeeded else JobState.FAILED)
        # The reason must be derived from the job's exit code (not from the `exit` builtin,
        # which never equals 0 and would always record NonZeroExitCode).
        update_db_value(
            db_file, job_idx, key="REASON", value=JobReason.NOREASON if succeeded else JobReason.NonZeroExitCode
        )

200 

201 

def main():
    """Entry point of the sbatch mock.

    Parses the sbatch-like command line arguments, registers the job in the DB and
    launches it through `launch_job` in a separate process. Unknown arguments are ignored.
    If --parsable is set, the job index is printed (without trailing newline).

    Raises
    ------
    ValueError
        If the --dependency string can not be parsed (see `parse_dependency`). The error is
        raised before the job is registered in the DB.
    """
    parser = argparse.ArgumentParser(
        description="Slurm sbatch mock.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("--account", "-A", type=str, dest="ACCOUNT", help="user account", required=False)
    parser.add_argument(
        "--dependency",
        "-d",
        type=str,
        default="",
        dest="dependency",
        help="Defer the start of this job until the specified dependencies have been satisfied",
    )
    parser.add_argument(
        "--error", "-e", type=str, dest="stderr", help="error file of slurm job", default="slurm-%j.out"
    )
    parser.add_argument("--job-name", "-J", type=str, dest="NAME", help="job name", default="wrap")
    parser.add_argument(
        "--nodelist",
        "-w",
        type=str,
        dest="NODELIST",
        help="Request a specific list of hosts",
        nargs="*",
    )
    parser.add_argument(
        "--output", "-o", type=str, dest="stdout", help="output file of slurm job", default="slurm-%j.out"
    )
    parser.add_argument("--partition", "-p", type=str, dest="PARTITION", help="job partition", required=False)
    parser.add_argument(
        "--parsable",
        action="store_true",
        dest="parsable",
        help="Outputs only the job id number. Errors will still be displayed",
    )
    parser.add_argument("--reservation", type=str, dest="RESERVATION", help="job reservation", required=False)
    parser.add_argument(
        "--wrap",
        type=str,
        dest="CMD",
        help="Command to be executed",
        required=True,
    )

    args, _ = parser.parse_known_args()

    # Parse the dependencies first: an invalid specification must fail before the job is
    # registered, otherwise a job that is never launched would remain in the DB.
    job_dependency = parse_dependency(args.dependency)

    # Only the arguments set by the user (or having a default) are stored in the DB,
    # without the options that only drive the mock itself.
    defined_args = {arg: value for arg, value in vars(args).items() if value is not None}
    for mock_only_arg in ("stdout", "stderr", "parsable", "dependency"):
        defined_args.pop(mock_only_arg)

    with get_db_file_handle(find_db_file()) as db_file:
        job_idx = append_job(db_file, **defined_args)

    # TODO: raise error if arguments to sbatch aren't valid (reservation, partition, nodelist ? etc.)

    # Pass the function and its arguments explicitly instead of a lambda: a lambda can not
    # be pickled, which breaks the "spawn" and "forkserver" multiprocessing start methods.
    detach_p = Process(
        target=launch_job,
        args=(
            args.CMD,
            args.stdout.replace("%j", str(job_idx)).strip(),
            args.stderr.replace("%j", str(job_idx)).strip(),
            job_idx,
            job_dependency,
        ),
    )
    detach_p.start()
    detach_p.join()  # wait for the detach process that stayed in the parent process group

    # if parsable is set: print jobID
    if args.parsable:
        print(job_idx, end="")

273 

274 

# Run the sbatch mock only when executed as a script, not when imported.
if __name__ == "__main__":
    main()