forked from facebookresearch/demucs
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_slurm.py
executable file
·89 lines (77 loc) · 2.64 KB
/
run_slurm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Run training from Slurm on all visible GPUs. Start only
one task per node as this script will spawn one child for each GPU.
This will not schedule a job but instead should be launched from srun/sbatch.
"""
import os
import subprocess as sp
import sys
import time
import torch as th
from demucs.utils import free_port
def main():
args = sys.argv[1:]
gpus = th.cuda.device_count()
n_nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
node_id = int(os.environ['SLURM_NODEID'])
job_id = int(os.environ['SLURM_JOBID'])
rank_offset = gpus * node_id
hostnames = sp.run(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']],
capture_output=True,
check=True).stdout
master_addr = hostnames.split()[0].decode('utf-8')
if n_nodes == 1:
port = free_port()
else:
port = 20_000 + (job_id % 40_000)
args += ["--world_size", str(n_nodes * gpus), "--master", f"{master_addr}:{port}"]
tasks = []
print("About to go live", master_addr, node_id, n_nodes, file=sys.stderr)
sys.stderr.flush()
for gpu in range(gpus):
kwargs = {}
if gpu > 0:
kwargs['stdin'] = sp.DEVNULL
kwargs['stdout'] = sp.DEVNULL
# We keep stderr to see tracebacks from children.
tasks.append(
sp.Popen(["python3", "-m", "demucs"] + args +
["--rank", str(rank_offset + gpu)], **kwargs))
tasks[-1].rank = rank_offset + gpu
failed = False
try:
while tasks:
for task in tasks:
try:
exitcode = task.wait(0.1)
except sp.TimeoutExpired:
continue
else:
tasks.remove(task)
if exitcode:
print(f"Task {task.rank} died with exit code "
f"{exitcode}",
file=sys.stderr)
failed = True
else:
print(f"Task {task.rank} exited successfully")
if failed:
break
time.sleep(1)
except KeyboardInterrupt:
for task in tasks:
task.terminate()
raise
if failed:
for task in tasks:
task.terminate()
sp.run(["scancel", str(job_id)], check=True)
sys.exit(1)
if __name__ == "__main__":
main()