Skip to content

Commit 0d3c0c2

Browse files
authored
[None] [chore] Enhancements and clean up to slurm scripts (#9493)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
1 parent 389b73c commit 0d3c0c2

File tree

5 files changed

+89
-85
lines changed

5 files changed

+89
-85
lines changed

examples/disaggregated/slurm/benchmark/config.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ slurm:
55
account: "<account>"
66
job_time: "02:00:00"
77
job_name: "<job_name>"
8+
extra_args: "" # Cluster specific arguments, e.g. "--gres=gpu:4 --exclude=node1,node2"
89
numa_bind: true # Only enable for GB200 NVL72
910

1011
# Benchmark Mode
@@ -34,12 +35,14 @@ environment:
3435
build_wheel: false # Don't build the wheel when launching multiple jobs
3536
trtllm_wheel_path: "" # Path to pre-built TensorRT-LLM wheel. If provided, install from this wheel instead
3637
work_dir: "<full_path_to_work_dir>"
37-
worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1"
38-
server_env_var: ""
38+
worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1 TRTLLM_ENABLE_PDL=1 ENROOT_ALLOW_DEV=yes"
39+
server_env_var: "TRTLLM_SERVER_DISABLE_GC=1"
3940

4041
# Profiling Configuration
4142
profiling:
4243
nsys_on: false # Set to true to enable profiling
44+
ctx_profile_range: "10-30" # Set TLLM_PROFILE_START_STOP for ctx workers
45+
gen_profile_range: "200-250" # Set TLLM_PROFILE_START_STOP for gen workers
4346

4447
# Accuracy Configuration
4548
accuracy:

examples/disaggregated/slurm/benchmark/disaggr_torch.slurm

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,20 @@ trtllm_wheel_path=${28}
4040

4141
# Profiling
4242
nsys_on=${29}
43+
ctx_profile_range=${30}
44+
gen_profile_range=${31}
4345

4446
# Accuracy evaluation
45-
enable_accuracy_test=${30}
46-
accuracy_model=${31}
47-
accuracy_tasks=${32}
48-
model_args_extra=${33}
47+
enable_accuracy_test=${32}
48+
accuracy_model=${33}
49+
accuracy_tasks=${34}
50+
model_args_extra=${35}
4951

5052
# Worker environment variables
51-
worker_env_var=${34}
53+
worker_env_var=${36}
5254

5355
# Server environment variables
54-
server_env_var=${35}
56+
server_env_var=${37}
5557

5658
# Print all parsed arguments
5759
echo "Parsed arguments:"
@@ -90,6 +92,8 @@ echo " build_wheel: ${build_wheel}"
9092
echo " trtllm_wheel_path: ${trtllm_wheel_path}"
9193
echo " work_dir: ${work_dir}"
9294
echo " nsys_on: ${nsys_on}"
95+
echo " ctx_profile_range: ${ctx_profile_range}"
96+
echo " gen_profile_range: ${gen_profile_range}"
9397
echo
9498
echo "Accuracy Configuration:"
9599
echo " enable_accuracy_test: ${enable_accuracy_test}"
@@ -169,17 +173,6 @@ elif [ -d "${trtllm_repo}" ]; then
169173
echo "TensorRT-LLM installation completed successfully"
170174
fi
171175

172-
# Get enable_pdl from gen config
173-
enable_pdl=$(python3 -c "import yaml; import sys;
174-
try:
175-
with open('${gen_config_path}') as f:
176-
c = yaml.safe_load(f)
177-
print(str(not c.get('enable_attention_dp', True)).lower())
178-
except Exception as e:
179-
print(f'Error reading config: {e}', file=sys.stderr)
180-
sys.exit(1)
181-
")
182-
183176
# Get node lists
184177
all_nodes=($(scontrol show hostname $SLURM_NODELIST | sort))
185178
total_nodes_num=${#all_nodes[@]}
@@ -211,7 +204,7 @@ for i in $(seq 0 $((num_gen_servers - 1))); do
211204
--container-mounts=${container_mount} \
212205
--mpi=pmix \
213206
bash ${work_dir}/start_worker.sh \
214-
"GEN" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${enable_pdl}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${gen_config_path}" "${worker_env_var}" \
207+
"GEN" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${gen_profile_range}" "${gen_config_path}" "${worker_env_var}" \
215208
&> ${full_logdir}/output_gen_${i}.log &
216209
done
217210

@@ -226,7 +219,7 @@ for i in $(seq 0 $((num_ctx_servers - 1))); do
226219
--container-mounts=${container_mount} \
227220
--mpi=pmix \
228221
bash ${work_dir}/start_worker.sh \
229-
"CTX" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${enable_pdl}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${ctx_config_path}" "${worker_env_var}" \
222+
"CTX" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${ctx_profile_range}" "${ctx_config_path}" "${worker_env_var}" \
230223
&> ${full_logdir}/output_ctx_${i}.log &
231224
done
232225

examples/disaggregated/slurm/benchmark/start_worker.sh

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ model_path=${3}
99
port=${4}
1010
benchmark_mode=${5}
1111
concurrency=${6}
12-
enable_pdl=${7}
13-
numa_bind=${8}
14-
log_dir=${9}
15-
enable_nsys=${10}
12+
numa_bind=${7}
13+
log_dir=${8}
14+
enable_nsys=${9}
15+
profile_range=${10}
1616
config_file=${11}
1717
worker_env_var=${12}
1818

1919
unset UCX_TLS
20-
echo "enable_pdl: ${enable_pdl}, log_dir: ${log_dir}"
2120
echo "SLURM_PROCID: ${SLURM_PROCID}, hostname: $(hostname), instance_id: ${instance_id}"
2221

2322
# Export worker environment variables from config
@@ -26,10 +25,6 @@ for env_var in ${worker_env_var}; do
2625
echo "Exported: ${env_var}"
2726
done
2827

29-
if [ "${enable_pdl}" = "true" ]; then
30-
export TRTLLM_ENABLE_PDL=1
31-
fi
32-
3328
if [ "${numa_bind}" = "true" ]; then
3429
numa_bind_cmd="numactl -m 0,1"
3530
echo "numactl -m 0,1 - Only allocate memory from nodes on GB200"
@@ -45,34 +40,27 @@ fi
4540

4641
echo "config_file: ${config_file}"
4742

48-
# save the hostname to a file
49-
50-
# if SLURM_NODEID is 0
43+
# if SLURM_NODEID is 0, save the hostname to a file
5144
if [ "${SLURM_NODEID}" = "0" ]; then
5245
mkdir -p ${log_dir}/hostnames/
5346
echo $(hostname) > ${log_dir}/hostnames/${role}_${instance_id}.txt
5447
echo "hostname saved to ${log_dir}/hostnames/${role}_${instance_id}.txt"
5548
fi
5649

57-
#check if nsys is enabled
50+
nsys_prefix=""
5851
if [ "${enable_nsys}" != "true" ]; then
5952
echo "nsys is not enabled, start normal flow"
60-
trtllm-llmapi-launch ${numa_bind_cmd} trtllm-serve ${model_path} --host $(hostname) --port ${port} --extra_llm_api_options ${config_file}
6153
else
62-
nsys_prefix=""
6354
nsys_file=${log_dir}/nsys_worker_proc_${role}_${instance_id}_${SLURM_PROCID}
6455
export TLLM_PROFILE_RECORD_GC=1
6556
export TLLM_NVTX_DEBUG=1
66-
nsys_prefix="nsys profile -e \"NSYS_MPI_STORE_TEAMS_PER_RANK=1\" -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none"
67-
if [ "${role}" = "GEN" ]; then
68-
export TLLM_PROFILE_START_STOP=200-250
69-
echo "nsys is enabled on gen_gpus"
70-
elif [ "${role}" = "CTX" ]; then
71-
export TLLM_PROFILE_START_STOP=10-30
72-
echo "nsys is enabled on ctx_gpus"
73-
fi
74-
${nsys_prefix} trtllm-llmapi-launch ${numa_bind_cmd} \
75-
trtllm-serve ${model_path} \
76-
--host $(hostname) --port ${port} \
77-
--extra_llm_api_options ${config_file}
57+
export NSYS_MPI_STORE_TEAMS_PER_RANK=1
58+
export TLLM_PROFILE_START_STOP=${profile_range}
59+
echo "nsys is enabled on ${role} GPUs, TLLM_PROFILE_START_STOP=${profile_range}"
60+
nsys_prefix="nsys profile -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none"
7861
fi
62+
63+
${nsys_prefix} trtllm-llmapi-launch ${numa_bind_cmd} \
64+
trtllm-serve ${model_path} \
65+
--host $(hostname) --port ${port} \
66+
--extra_llm_api_options ${config_file}

examples/disaggregated/slurm/benchmark/submit.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import shutil
77
import subprocess
88
import sys
9+
from datetime import datetime
910

1011
import yaml
1112

@@ -22,6 +23,10 @@ def parse_args():
2223
'--dir',
2324
type=str,
2425
help='Directory containing YAML configuration files')
26+
group.add_argument('--log-dir',
27+
type=str,
28+
default=None,
29+
help='Log directory')
2530
return parser.parse_args()
2631

2732

@@ -45,9 +50,11 @@ def calculate_nodes(world_size, num_servers, gpus_per_node):
4550
return (world_size + gpus_per_node - 1) // gpus_per_node * num_servers
4651

4752

48-
def submit_job(config):
53+
def submit_job(config, log_dir):
4954
# Extract configurations
5055
slurm_config = config['slurm']
56+
slurm_config.setdefault('extra_args', '')
57+
5158
hw_config = config['hardware']
5259
env_config = config['environment']
5360

@@ -71,6 +78,11 @@ def submit_job(config):
7178
env_config.setdefault('worker_env_var', '')
7279
env_config.setdefault('server_env_var', '')
7380

81+
profiling_config = config.get('profiling', {})
82+
profiling_config.setdefault('nsys_on', False)
83+
profiling_config.setdefault('ctx_profile_range', '10-30')
84+
profiling_config.setdefault('gen_profile_range', '200-250')
85+
7486
# Get number of servers from config
7587
ctx_num = hw_config['num_ctx_servers']
7688
gen_num = hw_config['num_gen_servers']
@@ -101,29 +113,35 @@ def submit_job(config):
101113
gen_enable_attention_dp = config['worker_config']['gen'][
102114
'enable_attention_dp']
103115

104-
# Create base log directory path
105-
log_base = os.path.join(env_config['work_dir'], f"{isl}-{osl}")
106-
107-
# Get eplb num_slots for gen worker
108-
load_balancer_config = config['worker_config']['gen'].get(
109-
'moe_config', {}).get('load_balancer', {})
110-
if isinstance(load_balancer_config, str):
111-
with open(load_balancer_config, 'r') as f:
112-
load_balancer_config = yaml.safe_load(f)
113-
eplb_num_slots = load_balancer_config.get('num_slots', 0)
114-
115-
# Determine directory suffix based on attention_dp
116-
if gen_enable_attention_dp:
117-
dir_suffix = f"ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
118-
else:
119-
dir_suffix = f"ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
116+
if log_dir is None:
117+
# Create base log directory path
118+
date_prefix = datetime.now().strftime("%Y%m%d")
119+
log_base = os.path.join(env_config['work_dir'],
120+
f"{date_prefix}/{isl}-{osl}")
121+
122+
# Get eplb num_slots for gen worker
123+
load_balancer_config = config['worker_config']['gen'].get(
124+
'moe_config', {}).get('load_balancer', {})
125+
if isinstance(load_balancer_config, str):
126+
with open(load_balancer_config, 'r') as f:
127+
load_balancer_config = yaml.safe_load(f)
128+
eplb_num_slots = load_balancer_config.get('num_slots', 0)
129+
130+
# Determine directory suffix based on attention_dp
131+
if gen_enable_attention_dp:
132+
dir_suffix = f"ctx{ctx_num}_gen{gen_num}_dep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
133+
else:
134+
dir_suffix = f"ctx{ctx_num}_gen{gen_num}_tep{gen_tp_size}_batch{gen_batch_size}_eplb{eplb_num_slots}_mtp{mtp_size}"
135+
136+
# Create full log directory path
137+
log_dir = os.path.join(log_base, dir_suffix)
120138

121-
# Create full log directory path
122-
log_dir = os.path.join(log_base, dir_suffix)
123139
# Remove existing directory if it exists
124140
if os.path.exists(log_dir):
141+
print(f"[WARNING] Removing existing log directory: {log_dir}")
125142
shutil.rmtree(log_dir)
126143
os.makedirs(log_dir)
144+
print(f"Log will be saved to: {log_dir}")
127145

128146
# Setup config file paths and save worker configs
129147
ctx_config_path = os.path.join(log_dir, 'ctx_config.yaml')
@@ -135,14 +153,14 @@ def submit_job(config):
135153
cmd = [
136154
'sbatch',
137155
f'--partition={slurm_config["partition"]}',
138-
f'--gres=gpu:{hw_config["gpus_per_node"]}',
139156
f'--account={slurm_config["account"]}',
140157
f'--time={slurm_config["job_time"]}',
141158
f'--job-name={slurm_config["job_name"]}',
142159
f'--nodes={total_nodes}',
143160
f'--ntasks={total_tasks}',
144161
f'--ntasks-per-node={hw_config["gpus_per_node"]}',
145162
f'--segment={total_nodes}',
163+
*([arg for arg in slurm_config['extra_args'].split() if arg]),
146164
slurm_config['script_file'],
147165
# Hardware configuration
148166
str(hw_config['gpus_per_node']),
@@ -182,7 +200,9 @@ def submit_job(config):
182200
env_config['trtllm_wheel_path'],
183201

184202
# Profiling
185-
str(config['profiling']['nsys_on']).lower(),
203+
str(profiling_config['nsys_on']).lower(),
204+
profiling_config['ctx_profile_range'],
205+
profiling_config['gen_profile_range'],
186206

187207
# Accuracy evaluation
188208
str(config['accuracy']['enable_accuracy_test']).lower(),
@@ -226,11 +246,11 @@ def main():
226246

227247
# Process each config file
228248
for config_file in config_files:
229-
print(f"\nProcessing: {config_file}")
249+
print(f"Processing: {config_file}")
230250
try:
231251
config = load_config(config_file)
232-
submit_job(config)
233-
print(f"Successfully submitted job for: {config_file}")
252+
submit_job(config, args.log_dir)
253+
print(f"Successfully submitted job for: {config_file}\n")
234254
except Exception as e:
235255
print(f"Error processing {config_file}: {e}", file=sys.stderr)
236256
# Continue processing other files even if one fails

examples/wide_ep/slurm_scripts/config.yaml

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,27 @@ slurm:
55
account: "<account>"
66
job_time: "02:00:00"
77
job_name: "<job_name>"
8+
extra_args: "" # Cluster specific arguments, e.g. "--gres=gpu:4 --exclude=node1,node2"
89
numa_bind: true # Only enable for GB200 NVL72
910

10-
# Hardware Configuration
11-
hardware:
12-
gpus_per_node: 4 # Modify this with your hardware configuration
13-
num_ctx_servers: 2 # Number of context servers
14-
num_gen_servers: 1 # Number of generation servers
15-
1611
# Benchmark Mode
1712
benchmark:
1813
mode: "e2e" # Options: e2e, gen_only
1914
use_nv_sa_benchmark: false # Whether to use NVIDIA SA benchmark script
20-
multi_round: 1 # Number of benchmark rounds
15+
multi_round: 8 # Number of benchmark rounds
2116
benchmark_ratio: 0.8 # Benchmark ratio
2217
streaming: true # Enable streaming mode
2318
concurrency_list: "1024"
2419
input_length: 8196 # Input sequence length
2520
output_length: 1024 # Output sequence length
2621
dataset_file: "<dataset_file>"
2722

23+
# Hardware Configuration
24+
hardware:
25+
gpus_per_node: 4 # Modify this with your hardware configuration
26+
num_ctx_servers: 1 # Number of context servers
27+
num_gen_servers: 1 # Number of generation servers
28+
2829
# Environment Configuration
2930
environment:
3031
container_mount: "<container_mount>" # Format: path1:path1,path2:path2
@@ -34,24 +35,24 @@ environment:
3435
build_wheel: false # Don't build the wheel when launching multiple jobs
3536
trtllm_wheel_path: "" # Path to pre-built TensorRT-LLM wheel. If provided, install from this wheel instead
3637
work_dir: "<full_path_to_work_dir>"
37-
worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1" # Environment variables for workers
38-
server_env_var: "" # Environment variables for server
38+
worker_env_var: "TLLM_LOG_LEVEL=INFO TRTLLM_SERVER_DISABLE_GC=1 TRTLLM_WORKER_DISABLE_GC=1 TRTLLM_ENABLE_PDL=1 ENROOT_ALLOW_DEV=yes"
39+
server_env_var: "TRTLLM_SERVER_DISABLE_GC=1"
3940

4041
# Profiling Configuration
4142
profiling:
4243
nsys_on: false # Set to true to enable profiling
44+
ctx_profile_range: "10-30" # Set TLLM_PROFILE_START_STOP for ctx workers
45+
gen_profile_range: "200-250" # Set TLLM_PROFILE_START_STOP for gen workers
4346

4447
# Accuracy Configuration
4548
accuracy:
4649
enable_accuracy_test: false # Set to true to enable accuracy evaluation
4750
model: "local-completions" # Model type for lm_eval
4851
tasks: "gsm8k" # Evaluation tasks (comma-separated)
49-
model_args_extra: "num_concurrent=512,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=256,max_length=512" # Extra model arguments for lm_eval
52+
model_args_extra: "num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096" # Extra model arguments for lm_eval
5053

51-
# Worker Configuration
5254
worker_config:
5355
gen:
54-
enable_layerwise_nvtx_marker: true
5556
tensor_parallel_size: 32
5657
moe_expert_parallel_size: 32
5758
enable_attention_dp: true
@@ -97,7 +98,6 @@ worker_config:
9798
decoding_type: MTP
9899
num_nextn_predict_layers: 3
99100
ctx:
100-
enable_layerwise_nvtx_marker: true
101101
max_batch_size: 1
102102
max_num_tokens: 8448
103103
max_seq_len: 8212

0 commit comments

Comments (0)