Skip to content
Snippets Groups Projects
job.py 2.36 KiB
Newer Older
#!/usr/bin/python

import os
import subprocess
import argparse

def makejob(commit_id, model, nruns, user, time_wall):
    """Build the text of a SLURM batch script that trains *model* at *commit_id*.

    The generated script copies the current directory to ``$TMPDIR/projet_dl``,
    checks out ``commit_id`` there, installs ``requirements.txt`` into a fresh
    virtualenv and runs ``main.py``.

    Args:
        commit_id: git commit hash checked out inside the job's scratch copy.
        model: model name; used as the SLURM job name and in the session banner.
        nruns: number of SLURM array tasks to launch (must be >= 1).
        user: not referenced by the generated script; kept for API compatibility.
        time_wall: value for ``#SBATCH --time`` (e.g. ``"1:00:00"``).

    Returns:
        The batch script as a string, ready to be written to a ``.sbatch`` file.

    Raises:
        ValueError: if ``nruns`` is smaller than 1.
    """
    if nruns < 1:
        raise ValueError(f"nruns must be >= 1, got {nruns}")
    # SLURM array ranges are inclusive, so nruns tasks are indices 0..nruns-1.
    last_task = nruns - 1
    return f"""#!/bin/bash

#SBATCH --job-name={model}
#SBATCH --nodes=1
#SBATCH --partition=gpu_prod_night
#SBATCH --time={time_wall}
#SBATCH --output=logslurms/slurm-%A_%a.out
#SBATCH --error=logslurms/slurm-%A_%a.err
#SBATCH --array=0-{last_task}


current_dir=`pwd`

echo "Session " {model}_${{SLURM_ARRAY_JOB_ID}}_${{SLURM_ARRAY_TASK_ID}}

echo "Copying the source directory and data"
date
mkdir $TMPDIR/projet_dl
rsync -r . $TMPDIR/projet_dl

echo "Checking out the correct version of the code commit_id {commit_id}"
# Abort instead of running `git checkout` in the submission directory if cd fails
cd $TMPDIR/projet_dl || exit 1
git checkout {commit_id}


echo "Setting up the virtual environment"
python3 -m pip install virtualenv --user
virtualenv -p python3 venv
source venv/bin/activate
python -m pip install -r requirements.txt

echo "Running main.py"
python3 main.py --logDir /usr/users/sdi1/sdi1_3/Projet_DL/Kaggle_Phytoplankton/logs/ --no_wandb

if [[ $? != 0 ]]; then
    exit -1
fi

# Once the job is finished, you can copy back
# files from $TMPDIR/projet_dl to $current_dir

"""

def submit_job(job):
    """Save the batch script *job* as ``job.sbatch`` and submit it with ``sbatch``.

    Args:
        job: full text of the SLURM batch script.

    The file is (over)written in the current working directory; the exit
    status of the ``sbatch`` command is not checked.
    """
    script_name = 'job.sbatch'
    with open(script_name, mode='w') as script:
        script.write(job)
    os.system("sbatch " + script_name)

# Map each --time_wall choice to a SLURM --time value (hours:minutes:seconds).
# Note: in SLURM a time of 0 means "no limit", so every entry must be non-zero.
time_wall = {
    "no-limit": "48:00:00",
    "no_limit": "48:00:00",  # historical spelling, kept for backward compatibility
    "hour": "1:00:00",
    "half": "0:30:00",
    "quarter": "0:15:00",
}

# Parse the command line first so that `--help` works even in a dirty checkout
parser = argparse.ArgumentParser()

parser.add_argument("--time_wall",
                    default="no-limit",
                    choices=list(time_wall),
                    help="Time wall. Choose in [no-limit, hour, half, quarter]")

parser.add_argument("--model_name",
                    default="Bi-LSTM",
                    help="Name of the model to train")

parser.add_argument("--user",
                    default=os.environ.get("USER", ""),
                    help="User submitting the job (forwarded to makejob)")

parser.add_argument("--nruns",
                    type=int,
                    default=1,
                    help="Number of runs passed to makejob (SLURM array size)")

args = parser.parse_args()

# Ensure all the modified files have been staged and committed, so that the
# commit checked out by the job matches the working tree being submitted
unstaged = subprocess.run(["git", "diff", "--name-only"],
                          stdout=subprocess.PIPE, check=True).stdout.decode().splitlines()
staged = subprocess.run(["git", "diff", "--name-only", "--cached"],
                        stdout=subprocess.PIPE, check=True).stdout.decode().splitlines()
result = len(unstaged) + len(staged)
if result > 0:
    print(f"We found {result} modifications either not staged or not commited")
    raise RuntimeError("You must stage and commit every modification before submission ")

# Full hash of the current HEAD commit, without surrounding whitespace
commit_id = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()

# Ensure the log directory used by the #SBATCH --output/--error lines exists
os.makedirs("logslurms", exist_ok=True)

# Launch the batch jobs
submit_job(makejob(commit_id, args.model_name, args.nruns, args.user,
                   time_wall[args.time_wall]))