#!/usr/bin/env python
"""Implement a mock of the squeue command of slurm.
The squeue -o, --format argument is supported except for the %all directive
"""
import argparse
import datetime
from typing import Callable, List, Tuple
import numpy as np
from mockslurm.process_db import (
    JobReason,
    JobState,
    find_db_file,
    get_db,
    get_db_file_handle,
    get_filtered_DB_mask,
)
from mockslurm.utils import filter_dict_from_args
# Default squeue output format (no -o/--format given)
_SQUEUE_DEFAULT_SHORT_FORMAT = "%.18i %.9P %.8j %.8u %.2t %.10M %.6D %R"
# Default squeue output format with -l/--long
_SQUEUE_DEFAULT_LONG_FORMAT = "%.18i %.9P %.8j %.8u %.8T %.10M %.9l %.6D %R"
# Compact state codes displayed by the %t format directive, keyed by job state
_SQUEUE_STATE_CODE_LONG_TO_SHORT = dict(
    [
        (JobState.PENDING, "PD"),
        (JobState.RUNNING, "R"),
        (JobState.FAILED, "F"),
        (JobState.COMPLETED, "CD"),
        (JobState.CANCELLED, "CA"),
    ]
)
def count_nodes(nodelist: str) -> int:
    """Count the number of nodes requested by users in the nodelist

    Parameters
    ----------
    nodelist : str
        Nodelist string in slurm format: comma separated list of host expressions.
        A host expression is a node name, optionally containing bracket groups that list
        indices and index ranges (themselves comma separated).
        Outside of brackets, "-" is an ordinary character of the node name.
        example: "cp12,cp18", "cp[10-14],cp28", "cp[10-14,16],gpu-node1"

    Returns
    -------
    int
        Number of nodes concerned by the nodelist. An empty (or None) nodelist counts as
        1 node, because the mock considers jobs without a nodelist are allocated 1 node.
    """
    if not nodelist:  # mock considers no nodelist jobs get allocated 1 node
        return 1

    # Split on the commas separating host expressions, i.e. the commas outside brackets:
    # "cp[10-14,16],cp28" -> ["cp[10-14,16]", "cp28"]
    host_exprs = []
    current = []
    depth = 0
    for char in nodelist:
        if char == "," and depth == 0:
            host_exprs.append("".join(current))
            current = []
            continue
        if char == "[":
            depth += 1
        elif char == "]":
            depth = max(depth - 1, 0)
        current.append(char)
    host_exprs.append("".join(current))

    nodecount = 0
    for host_expr in host_exprs:
        host_expr = host_expr.strip()
        if not host_expr:  # tolerate empty entries, e.g. a trailing comma
            continue
        # Only bracket groups hold ranges ("gpu-node1" is a single node). Each bracket group
        # expands independently, so "rack[1-2]n[1-3]" names 2 * 3 nodes.
        expr_count = 1
        remaining = host_expr
        while "[" in remaining:
            _, _, remaining = remaining.partition("[")
            group, _, remaining = remaining.partition("]")
            group_count = 0
            for index_spec in group.split(","):
                index_spec = index_spec.strip()
                if not index_spec:
                    continue
                first, dash, last = index_spec.partition("-")
                group_count += (int(last) - int(first) + 1) if dash else 1
            expr_count *= group_count
        nodecount += expr_count
    return nodecount
def _format_timestamp(timestamp) -> str:
    """Format a POSIX timestamp as a local date-time string ``YYYY-MM-DDTHH:MM:SS``."""
    return datetime.datetime.fromtimestamp(timestamp).strftime("%Y-%m-%dT%H:%M:%S")


def _format_elapsed_time(job_info) -> str:
    """Return how long a job has been running, formatted as ``[days-]hours:minutes:seconds``.

    Jobs that are not RUNNING report ``0:00:00``. Fractions of seconds are dropped, and
    durations of one day or more are written ``D-HH:MM:SS`` (``str(timedelta)`` would
    produce ``"1 day, 2:03:04"``, which does not fit squeue's column format).
    """
    if JobState(job_info["STATE"]) != JobState.RUNNING:
        return "0:00:00"
    elapsed = datetime.datetime.now() - datetime.datetime.fromtimestamp(job_info["START_TIME"])
    # clamp a start time in the future (clock skew) to 0 instead of printing a negative duration
    total_seconds = max(int(elapsed.total_seconds()), 0)
    days, remainder = divmod(total_seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)
    if days:
        return "{}-{:02d}:{:02d}:{:02d}".format(days, hours, minutes, seconds)
    return "{}:{:02d}:{:02d}".format(hours, minutes, seconds)


# Map each supported squeue --format letter to (column header, formatter).
# A formatter takes (job_info, job_idx) -- a DB row and its index, which is the job ID -- and returns a str.
# No letter is mapped yet for the job PID or its exit code.
SQUEUE_FORMAT_LETTER_TO_DB_FIELD = {
    "i": ("JOBID", lambda _, job_idx: str(job_idx)),
    "A": ("JOBID", lambda _, job_idx: str(job_idx)),
    "j": ("NAME", lambda job_info, _: job_info["NAME"].decode()),
    "u": ("USER", lambda job_info, _: job_info["USER"].decode()),
    "a": ("ACCOUNT", lambda job_info, _: job_info["ACCOUNT"].decode()),
    "P": ("PARTITION", lambda job_info, _: job_info["PARTITION"].decode()),
    "v": ("RESERVATION", lambda job_info, _: job_info["RESERVATION"].decode()),
    "M": ("TIME", lambda job_info, _: _format_elapsed_time(job_info)),
    "l": ("TIME_LIMIT", lambda job_info, _: "UNLIMITED"),  # no time limit in mock
    "n": ("REQ_NODES", lambda job_info, _: job_info["NODELIST"].decode()),
    "N": ("NODELIST", lambda job_info, _: job_info["NODELIST"].decode()),
    "D": ("NODES", lambda job_info, _: str(count_nodes(job_info["NODELIST"].decode()))),
    "S": ("START_TIME", lambda job_info, _: _format_timestamp(job_info["START_TIME"])),
    # submit time is equal to start time for mock
    "V": ("SUBMIT_TIME", lambda job_info, _: _format_timestamp(job_info["START_TIME"])),
    "o": ("COMMAND", lambda job_info, _: job_info["CMD"].decode()),
    "r": ("REASON", lambda job_info, _: JobReason(job_info["REASON"]).name),
    "t": ("ST", lambda job_info, _: _SQUEUE_STATE_CODE_LONG_TO_SHORT[JobState(job_info["STATE"])]),
    "T": ("STATE", lambda job_info, _: JobState(job_info["STATE"]).name),
    "R": (
        "NODELIST(REASON)",
        # "(REASON)" for pending jobs or jobs that ended with a non-zero exit code, the nodelist otherwise
        lambda job_info, _: (
            "(" + JobReason(job_info["REASON"]).name + ")"
            if JobState(job_info["STATE"]) == JobState.PENDING
            or JobReason(job_info["REASON"]) == JobReason.NonZeroExitCode
            else job_info["NODELIST"].decode()
        ),
    ),
}
def _build_parser() -> argparse.ArgumentParser:
    """Create the command line parser of the squeue mock.

    ``add_help`` is disabled because, like squeue, ``-h`` means ``--noheader``;
    the help message is printed with ``--usage`` instead.
    """
    parser = argparse.ArgumentParser(
        description="Slurm squeue mock", formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False
    )
    user_group = parser.add_mutually_exclusive_group()
    parser.add_argument(
        "--account",
        "-A",
        type=str,
        dest="ACCOUNT",
        help="Specify the accounts of the jobs to view. Accepts a comma separated list of account names",
    )
    parser.add_argument(
        "--name",
        "-n",
        type=str,
        dest="NAME",
        help="Request jobs having one of the specified names. The list consists of a comma separated list of job names.",
    )
    user_group.add_argument(
        "--me",
        action="store_true",
        dest="me",
        help="Equivalent to --user=<my username>",
    )
    parser.add_argument(
        "--nodelist",
        "-w",
        type=str,
        dest="NODELIST",
        help="Report only on jobs allocated to the specified node or list of nodes",
    )
    parser.add_argument(
        "--format",
        "-o",
        type=str,
        default=None,  # resolved in main(): an explicit format wins over --long
        dest="format_str",
        help="Specify the information to be displayed, its size and position. "
        "When omitted, the squeue default format is used (long format if --long is given)",
    )
    parser.add_argument(
        "--noheader", "-h", action="store_true", dest="no_header", help="Do not print a header on the output"
    )
    parser.add_argument(
        "--long",
        "-l",
        action="store_true",
        dest="long",
        help="Report more of the available information for the selected jobs or job steps, subject to any constraints specified",
    )
    parser.add_argument(
        "--partition",
        "-p",
        type=str,
        dest="PARTITION",
        help="Specify the partitions of the jobs or steps to view. Accepts a comma separated list of partition names",
    )
    parser.add_argument(
        "--reservation",
        "-R",
        type=str,
        dest="RESERVATION",
        help="Specify the reservation of the jobs to view",
    )
    parser.add_argument(
        "--usage",
        action="store_true",
        dest="print_help",
        help="Print a brief help message listing the squeue options",
    )
    parser.add_argument(
        "--jobs",
        "-j",
        type=str,
        dest="jobs",
        help="Specify a comma separated list of job IDs to display, Defaults to all jobs",
    )
    return parser


def main():
    """Run the squeue mock: parse the command line, filter the job DB and print the matching jobs.

    Without ``--jobs``, only RUNNING and PENDING jobs are listed. With ``--jobs``, the listing is
    restricted to the requested job IDs (job IDs are the indices of the jobs in the DB). The
    running/pending filter is then not applied, so finished requested jobs are listed too, but the
    other filters (account, name, partition, ...) still apply. Requested IDs that do not exist are
    silently dropped; if none remain, an error is printed and the process exits with status 1.
    """
    import sys  # sys.exit: the ``exit`` builtin is only installed by the ``site`` module

    parser = _build_parser()
    args = parser.parse_args()
    if args.print_help:
        parser.print_help()
        return
    # -o/--format takes precedence over -l/--long, which takes precedence over the short default format
    if args.format_str is None:
        args.format_str = _SQUEUE_DEFAULT_LONG_FORMAT if args.long else _SQUEUE_DEFAULT_SHORT_FORMAT
    # Split args that take list in comma separated string to list!
    args.ACCOUNT = args.ACCOUNT.split(",") if args.ACCOUNT is not None else None
    args.NAME = args.NAME.split(",") if args.NAME is not None else None
    args.PARTITION = args.PARTITION.split(",") if args.PARTITION is not None else None
    args.jobs = (
        [int(job_id) for job_id in args.jobs.split(",") if job_id.strip()] if args.jobs is not None else None
    )
    field_filter_values = filter_dict_from_args(args)
    if not args.jobs:
        # filter out finished jobs, unless specific job IDs were requested
        field_filter_values["STATE"] = [JobState.RUNNING, JobState.PENDING]
    with get_db_file_handle(find_db_file()) as db_file:
        # Get mask to select DB rows
        mask = get_filtered_DB_mask(db_file, field_filter_values)
        # restrict the output to the requested job IDs, if some were specified
        if args.jobs and len(mask) > 0:  # if mask is empty (no jobs in DB), skip
            # job IDs are just the index of the jobs in the DB!
            # remove (silently like real squeue...) job IDs that do not exist
            existing_ids = [job_id for job_id in args.jobs if 0 <= job_id < len(mask)]
            if not existing_ids:  # no jobs remaining: exit with error like squeue
                print("slurm_load_jobs error: Invalid job id specified")
                sys.exit(1)
            requested = np.zeros(len(mask), dtype=bool)
            requested[existing_ids] = True
            mask = np.logical_and(mask, requested)
        # Get format string and function formating field value based on squeue --format
        format_str, fields_header, fields_filler_fct = parse_squeue_format(args.format_str)
        # Print header
        if not args.no_header:
            print(format_str.format(*fields_header))
        # Print the jobs found in filtered DB
        job_indices = np.nonzero(mask)[0]
        for idx, job in zip(job_indices, get_db(db_file)[mask]):
            print(format_str.format(*[fct(job, idx) for fct in fields_filler_fct]))


if __name__ == "__main__":
    main()