Coverage for mockslurm/mock_squeue.py: 94%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-09-04 23:38 +0000

1#!/usr/bin/env python 

2"""Implement a mock of the squeue command of slurm. 

3 

4The squeue -o, --format argument is supported except for the %all directive 

5""" 

6 

7import argparse 

8import datetime 

9from collections import defaultdict 

10from typing import Callable, List, Tuple 

11 

12import numpy as np 

13from mockslurm.process_db import ( 

14 JobReason, 

15 JobState, 

16 find_db_file, 

17 get_db, 

18 get_db_file_handle, 

19 get_filtered_DB_mask, 

20) 

21from mockslurm.utils import filter_dict_from_args 

22 

# Output layouts expressed with squeue's own "%[.][size]letter" directives.
# The short layout is the default value of the -o/--format option.
_SQUEUE_DEFAULT_SHORT_FORMAT = "%.18i %.9P %.8j %.8u %.2t %.10M %.6D %R"
# The long layout shows the full state name (%T) and adds the time limit (%l).
_SQUEUE_DEFAULT_LONG_FORMAT = "%.18i %.9P %.8j %.8u %.8T %.10M %.9l %.6D %R"

# Compact state codes printed by the "%t" directive, keyed by job state.
_SQUEUE_STATE_CODE_LONG_TO_SHORT = {
    JobState.PENDING: "PD",
    JobState.RUNNING: "R",
    JobState.FAILED: "F",
    JobState.COMPLETED: "CD",
    JobState.CANCELLED: "CA",
}

33 

34 

def count_nodes(nodelist: str) -> int:
    """Count the number of nodes requested by users in the nodelist

    Parameters
    ----------
    nodelist : str
        Nodelist string in slurm format: comma separated list of nodes or node ranges
        example: "cp12,cp18", "cp[10-14],cp28"

    Returns
    -------
    int
        Number of nodes concerned by the nodelist. An empty nodelist counts as 1.

    Notes
    -----
    A "-" is only interpreted as a range separator when both of its sides are
    integers. Hyphenated host names such as "gpu-a100-01" therefore count as a
    single node instead of raising ``ValueError``.
    """
    if not nodelist:  # mock considers no nodelist jobs get allocated 1 node
        return 1

    nodecount = 0
    for node_specs in nodelist.split(","):
        # Drop the node name prefix ("cp[10-14" -> "10-14") and anything from the
        # closing bracket on ("15]" -> "15"). A bracket group holding several comma
        # separated entries reaches this loop as one piece per entry.
        node_specs = node_specs.rpartition("[")[2].split("]")[0]
        first, _, last = node_specs.partition("-")
        try:
            # "first-last" is an inclusive numeric range
            nodecount += int(last) - int(first) + 1
        except ValueError:
            # No "-" at all (last is empty), or the "-" belongs to a host name:
            # this piece designates a single node.
            nodecount += 1

    return nodecount

65 

66 

def _format_start_time(job_info) -> str:
    """Render the START_TIME field of a job DB row as "YYYY-MM-DDTHH:MM:SS".

    START_TIME is treated as a POSIX timestamp, exactly as the "%M" directive
    does when it hands the same field to ``datetime.fromtimestamp``.
    """
    return datetime.datetime.fromtimestamp(job_info["START_TIME"]).isoformat(timespec="seconds")


# Maps each supported squeue format letter ("%<letter>") to a tuple of
#   (column header, function(job_info, job_idx) -> str giving the column value)
# where job_info is a job DB row and job_idx the index of that row in the DB.
SQUEUE_FORMAT_LETTER_TO_DB_FIELD = {
    # "": lambda job_info, _: job_info["PID"],
    "i": ("JOBID", lambda _, job_idx: str(job_idx)),
    "A": ("JOBID", lambda _, job_idx: str(job_idx)),
    "j": ("NAME", lambda job_info, _: job_info["NAME"].decode()),
    "u": ("USER", lambda job_info, _: job_info["USER"].decode()),
    "a": ("ACCOUNT", lambda job_info, _: job_info["ACCOUNT"].decode()),
    "P": ("PARTITION", lambda job_info, _: job_info["PARTITION"].decode()),
    "v": ("RESERVATION", lambda job_info, _: job_info["RESERVATION"].decode()),
    "M": (
        "TIME",
        # Elapsed time is only reported for running jobs. The split removes the
        # fractional part (microseconds), which should not be displayed.
        lambda job_info, _: (
            str(datetime.datetime.now() - datetime.datetime.fromtimestamp(job_info["START_TIME"])).split(".")[0]
            if job_info["STATE"] == JobState.RUNNING
            else "0:00:00"
        ),
    ),
    "l": ("TIME_LIMIT", lambda job_info, _: "UNLIMITED"),  # no time limit in mock
    "n": ("REQ_NODES", lambda job_info, _: job_info["NODELIST"].decode()),
    "N": ("NODELIST", lambda job_info, _: job_info["NODELIST"].decode()),
    "D": ("NODES", lambda job_info, _: str(count_nodes(job_info["NODELIST"].decode()))),
    # START_TIME is a timestamp, not bytes: it has to be formatted, not decoded
    "S": ("START_TIME", lambda job_info, _: _format_start_time(job_info)),
    # submit time is equal to start time for the mock
    "V": ("SUBMIT_TIME", lambda job_info, _: _format_start_time(job_info)),
    "o": ("COMMAND", lambda job_info, _: job_info["CMD"].decode()),
    # the DB row holds the raw reason value: go through the enum to get its name
    "r": ("REASON", lambda job_info, _: JobReason(job_info["REASON"]).name),
    "t": ("ST", lambda job_info, _: _SQUEUE_STATE_CODE_LONG_TO_SHORT[job_info["STATE"]]),
    "T": ("STATE", lambda job_info, _: JobState(job_info["STATE"]).name),
    "R": (
        "NODELIST(REASON)",
        # Pending jobs and jobs that exited with a non zero code show "(reason)",
        # every other job shows its node list.
        lambda job_info, _: (
            "(" + JobReason(job_info["REASON"]).name + ")"
            if JobState(job_info["STATE"]) == JobState.PENDING
            or JobReason(job_info["REASON"]) == JobReason.NonZeroExitCode
            else job_info["NODELIST"].decode()
        ),
    ),
    # "": lambda job_info, _: job_info["EXIT_CODE"],
}

108 

109 

def parse_squeue_format(squeue_format_str: str) -> Tuple[str, List[str], List[Callable]]:
    """Convert squeue format argument (-o format) to a python format string, the list
    of column headers and the list of functions filling the values of the format
    string for each job DB row.

    Parameters
    ----------
    squeue_format_str : str
        Squeue format string, eg "%.18i %.9P %.8j %.8u %.2t %.10M %.6D %R".

    Warning
    -------
    Does not support squeue's "%all" formatting string.

    Returns
    -------
    Tuple[str, List[str], List[Callable]]
        python format string, the header of each field (in order of appearance), and
        a list of callable that should be used to fill the values in the formatted
        string with a job DB row and job_index.

    Raises
    ------
    ValueError
        If a "%" directive has no type letter (eg "%" or "%.18" ending the string).
    KeyError
        If the type letter of a directive is not supported by the mock.

    Examples
    --------
    >>> format_str, headers, fields_filler_fcts = parse_squeue_format("%.18i %.9P %.8j %.8u %.2t %.10M %.6D %R")
    >>> squeue_header = format_str.format(*headers)
    >>> squeue_output = format_str.format(*[fct(job_DB_row, row_idx) for fct in fields_filler_fcts])
    """

    def escape_literal(text: str) -> str:
        # Literal braces must be doubled to go through str.format() unchanged
        return text.replace("{", "{{").replace("}", "}}")

    # Each "%" starts a directive "[.][size]letter", followed by literal text.
    # It becomes "{:[>][size]}" + literal text, with ">" (right justification)
    # when "." follows "%". The function filling the value is taken from the
    # map of squeue format letters to DB function.
    # The first split element is the literal text preceding the first directive
    # (empty if the string starts with "%"): it is kept as is.
    leading_text, *directives = squeue_format_str.split("%")
    format_parts = [escape_literal(leading_text)]
    fields_header = []
    fields_filler_functions = []
    for field_format_str in directives:
        idx = 1 if field_format_str.startswith(".") else 0
        alignment = ">" if idx else ""
        size_start = idx
        while idx < len(field_format_str) and field_format_str[idx].isdigit():
            idx += 1
        if idx == len(field_format_str):
            raise ValueError(f"Missing field type after '%' in squeue format: '%{field_format_str}'")
        header, fct = SQUEUE_FORMAT_LETTER_TO_DB_FIELD[field_format_str[idx]]
        fields_header.append(header)
        fields_filler_functions.append(fct)
        format_parts.append(
            "{:" + alignment + field_format_str[size_start:idx] + "}" + escape_literal(field_format_str[idx + 1 :])
        )

    return "".join(format_parts), fields_header, fields_filler_functions

157 

158 

def main():
    """Command line entry point of the squeue mock.

    Parses the squeue-like command line arguments, selects the pending and running
    jobs of the mock DB that match the requested filters, and prints them (after an
    optional header line) using the squeue format given with --format.
    """
    # add_help=False: "-h" is squeue's short option for --noheader, not for help
    parser = argparse.ArgumentParser(
        description="Slurm squeue mock", formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False
    )
    user_group = parser.add_mutually_exclusive_group()
    parser.add_argument(
        "--account",
        "-A",
        type=str,
        dest="ACCOUNT",
        help="Specify the accounts of the jobs to view. Accepts a comma separated list of account names",
    )
    parser.add_argument(
        "--name",
        "-n",
        type=str,
        dest="NAME",
        help="Request jobs having one of the specified names. The list consists of a comma separated list of job names.",
    )
    user_group.add_argument(
        "--me",
        action="store_true",
        dest="me",
        help="Equivalent to --user=<my username>",
    )
    parser.add_argument(
        "--nodelist",
        "-w",
        type=str,
        dest="NODELIST",
        help="Report only on jobs allocated to the specified node or list of nodes",
    )
    # No default here: the default layout depends on --long and is resolved after
    # parsing. A non-None default would make it impossible to know whether the
    # user asked for a specific format.
    parser.add_argument(
        "--format",
        "-o",
        type=str,
        default=None,
        dest="format_str",
        help="Specify the information to be displayed, its size and position",
    )
    parser.add_argument(
        "--noheader", "-h", action="store_true", dest="no_header", help="Do not print a header on the output"
    )
    parser.add_argument(
        "--long",
        "-l",
        action="store_true",
        dest="long",
        help="Report more of the available information for the selected jobs or job steps, subject to any constraints specified",
    )
    parser.add_argument(
        "--partition",
        "-p",
        type=str,
        dest="PARTITION",
        help="Specify the partitions of the jobs or steps to view. Accepts a comma separated list of partition names",
    )
    parser.add_argument(
        "--reservation",
        "-R",
        type=str,
        dest="RESERVATION",
        help="Specify the reservation of the jobs to view",
    )
    parser.add_argument(
        "--usage", action="store_true", dest="print_help", help="Print a brief help message listing the squeue options"
    )
    parser.add_argument(
        "--jobs",
        "-j",
        type=str,
        dest="jobs",
        help="Specify a comma separated list of job IDs to display, Defaults to all jobs",
    )
    args = parser.parse_args()

    if args.print_help:
        # --usage only displays the help: do not query the job DB afterwards
        parser.print_help()
        return

    # An explicit --format always wins; otherwise --long selects the long layout
    if args.format_str is None:
        args.format_str = _SQUEUE_DEFAULT_LONG_FORMAT if args.long else _SQUEUE_DEFAULT_SHORT_FORMAT

    # Split args that take list in comma separated string to list!
    args.ACCOUNT = args.ACCOUNT.split(",") if args.ACCOUNT is not None else None
    args.NAME = args.NAME.split(",") if args.NAME is not None else None
    args.PARTITION = args.PARTITION.split(",") if args.PARTITION is not None else None

    # Job IDs are parsed into a local list; args.jobs itself is left untouched
    requested_job_ids = None
    if args.jobs:
        try:
            requested_job_ids = [int(job_id) for job_id in args.jobs.split(",")]
        except ValueError:
            parser.error("Invalid job id specified: " + args.jobs)

    field_filter_values = filter_dict_from_args(args)
    # filter out finished jobs
    field_filter_values["STATE"] = [JobState.RUNNING, JobState.PENDING]

    with get_db_file_handle(find_db_file()) as db_file:
        # Get mask to select DB rows
        mask = get_filtered_DB_mask(db_file, field_filter_values)
        # Filter job IDs if some were specified: only keep the requested rows.
        # NOTE(review): assumes mask is a 1-D boolean numpy array with one entry per
        # DB row and that a job ID is the index of its row (what "%i" prints) - confirm.
        if requested_job_ids is not None:
            requested = np.zeros(len(mask), dtype=bool)
            for job_id in requested_job_ids:
                if 0 <= job_id < len(mask):  # IDs that are not in the DB match nothing
                    requested[job_id] = True
            mask = mask & requested

        # Get format string and function formating field value based on squeue --format
        format_str, fields_header, fields_filler_fct = parse_squeue_format(args.format_str)

        # Print header
        if not args.no_header:
            print(format_str.format(*fields_header))

        # Print the jobs found in filtered DB
        job_indices = np.nonzero(mask)[0]
        for idx, job in zip(job_indices, get_db(db_file)[mask]):
            print(format_str.format(*[fct(job, idx) for fct in fields_filler_fct]))

268 

269 

# Run the squeue mock only when this file is executed as a script, not on import
if __name__ == "__main__":
    main()