
Commit 4faf80b

Remove SKIP_JAX_PRECOMPILE
- Same functionality can be achieved with the vLLM argument --enforce-eager
- It is better to remove duplicate configs to avoid confusing users

Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 0c66fde commit 4faf80b
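
For anyone consuming this change, the swap is mechanical: drop the environment variable and request eager mode through vLLM itself. A minimal sketch of the before/after (not part of this commit; the model name is only the example default used elsewhere in this repo):

    # Before:  SKIP_JAX_PRECOMPILE=1 vllm serve $MODEL
    # After:   vllm serve $MODEL --enforce-eager

    # The same option through the Python API:
    from vllm import LLM

    # enforce_eager=True disables the graph precompilation warmup,
    # which is what SKIP_JAX_PRECOMPILE=1 used to control.
    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)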

File tree

8 files changed: +14 -19 lines changed

.buildkite/models/google_gemma-3-27b-it.yml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ steps:
     commands:
     - |
       .buildkite/scripts/run_in_docker.sh \
-        bash -c 'SKIP_JAX_PRECOMPILE=1 VLLM_XLA_CHECK_RECOMPILATION=0 python3 /workspace/tpu_inference/examples/offline_inference.py --model=google/gemma-3-27b-it --tensor_parallel_size=8 --task=generate --max_model_len=1024 --max_num_seqs=1'
+        bash -c 'VLLM_XLA_CHECK_RECOMPILATION=0 python3 /workspace/tpu_inference/examples/offline_inference.py --model=google/gemma-3-27b-it --tensor_parallel_size=8 --task=generate --max_model_len=1024 --max_num_seqs=1'
 - label: "Record unit test result for google/gemma-3-27b-it"
   key: "record_google_gemma-3-27b-it_UnitTest"
   depends_on: "google_gemma-3-27b-it_UnitTest"

examples/disagg/run_disagg_multi_host.sh

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,6 @@ for ((i=0; i<NUM_HOSTS_PER_INSTANCE; i++)); do
     -e TPU_KV_TRANSFER_PORT="${KV_PORT}" \
     -e TPU_SIDE_CHANNEL_PORT="${SIDE_PORT}" \
     -e RAY_DEDUP_LOGS="0" \
-    -e SKIP_JAX_PRECOMPILE="1" \
     \
     -e TPU_CHIPS_PER_PROCESS_BOUNDS="1,1,1" \
     -e TPU_PROCESS_BOUNDS="2,2,1" \
@@ -95,6 +94,7 @@ docker exec node-0 /bin/bash -c \
     --gpu-memory-utilization 0.3 \
     --tensor-parallel-size 4 \
     --kv-transfer-config '{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_producer\"}' \
+    --enforce-eager \
     > /root/logs/prefill.txt 2>&1 &"
 set +x

@@ -137,7 +137,6 @@ for ((i=0; i<NUM_HOSTS_PER_INSTANCE; i++)); do
     -e TPU_KV_TRANSFER_PORT="${KV_PORT}" \
     -e TPU_SIDE_CHANNEL_PORT="${SIDE_PORT}" \
     -e RAY_DEDUP_LOGS="0" \
-    -e SKIP_JAX_PRECOMPILE="1" \
     \
     -e TPU_CHIPS_PER_PROCESS_BOUNDS="1,1,1" \
     -e TPU_PROCESS_BOUNDS="2,2,1" \
@@ -169,5 +168,6 @@ docker exec node-20 /bin/bash -c \
     --gpu-memory-utilization 0.3 \
     --tensor-parallel-size 4 \
     --kv-transfer-config '{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_consumer\"}' \
+    --enforce-eager \
     > /root/logs/decode.txt 2>&1 &"
 set +x

examples/disagg/run_disagg_single_host.sh

Lines changed: 2 additions & 2 deletions
@@ -45,13 +45,13 @@ for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
     \
     TPU_KV_TRANSFER_PORT=$KV_PORT \
     TPU_SIDE_CHANNEL_PORT=$SIDE_PORT \
-    SKIP_JAX_PRECOMPILE=1 \
     \
     vllm serve $MODEL \
         --port $PORT \
         --gpu-memory-utilization 0.2 \
         --tensor-parallel-size $PREFILLER_TP_SIZE \
         --kv-transfer-config "{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_producer\"}" \
+        --enforce-eager \
         > $HOME/logs/prefill_$i.txt 2>&1 &

     PREFILL_HOSTS+=("localhost")
@@ -72,13 +72,13 @@ for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
     \
     TPU_KV_TRANSFER_PORT=$KV_PORT \
     TPU_SIDE_CHANNEL_PORT=$SIDE_PORT \
-    SKIP_JAX_PRECOMPILE=1 \
     \
     vllm serve $MODEL \
         --port $PORT \
         --gpu-memory-utilization 0.2 \
         --tensor-parallel-size $DECODER_TP_SIZE \
         --kv-transfer-config "{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_consumer\"}" \
+        --enforce-eager \
         > $HOME/logs/decode_$i.txt 2>&1 &

     DECODE_HOSTS+=("localhost")

examples/offline_inference.py

Lines changed: 3 additions & 5 deletions
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import os
-
 import vllm.envs as envs
 from vllm import LLM, EngineArgs
 from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -17,6 +15,9 @@ def create_parser():
     parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
     parser.set_defaults(max_model_len=1024)

+    # Skip long warmup for local simple test.
+    parser.set_defaults(enforce_eager=True)
+
     # Add sampling params
     sampling_group = parser.add_argument_group("Sampling parameters")
     sampling_group.add_argument("--max-tokens", type=int)
@@ -103,9 +104,6 @@ def main(args: dict):


 if __name__ == "__main__":
-    # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-
     parser = create_parser()
     args: dict = vars(parser.parse_args())

examples/offline_lora_inference.py

Lines changed: 3 additions & 4 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import os
 import time

 import vllm.envs as envs
@@ -20,6 +19,9 @@ def create_parser():
     parser.set_defaults(enable_lora=True)
     parser.set_defaults(max_lora_rank=8)

+    # Skip long warmup for local simple test.
+    parser.set_defaults(enforce_eager=True)
+
     # Add sampling params
     sampling_group = parser.add_argument_group("Sampling parameters")
     sampling_group.add_argument("--max-tokens", type=int, default=16)
@@ -76,9 +78,6 @@ def main(args: dict):


 if __name__ == "__main__":
-    # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-
     parser = create_parser()
     args: dict = vars(parser.parse_args())

tests/e2e/benchmarking/mm_bench.sh

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ checkThroughputAndRouge() {
 }

 echo "Spinning up the vLLM server..."
-(SKIP_JAX_PRECOMPILE=1 VLLM_XLA_CHECK_RECOMPILATION=0 vllm serve "$model_name" --max-model-len "$max_model_len" --max-num-seqs "$max_num_seqs" --disable-log-requests --max-num-batched-tokens "$max_batched_tokens" 2>&1 | tee -a "$LOG_FILE") &
+(VLLM_XLA_CHECK_RECOMPILATION=0 vllm serve "$model_name" --max-model-len "$max_model_len" --max-num-seqs "$max_num_seqs" --disable-log-requests --max-num-batched-tokens "$max_batched_tokens" --enforce-eager 2>&1 | tee -a "$LOG_FILE") &


 # Run a busy loop to block until the server is ready to receive requests

tests/e2e/test_multi_modal_inference.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,6 @@ def test_multi_modal_inference(monkeypatch):
     """
     Runs multi-modal inference and verifies the output.
     """
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'  # Skip warmup to save time.
     os.environ[
         'VLLM_XLA_CHECK_RECOMPILATION'] = '0'  # Allow compilation during execution.

@@ -65,6 +64,7 @@ def test_multi_modal_inference(monkeypatch):
             "fps": 1,
         },
         limit_mm_per_prompt={modality: 1},
+        enforce_eager=True,  # Skip warmup to save time.
     )
     engine_args = asdict(engine_args)
     llm = LLM(**engine_args)

tpu_inference/runner/compilation_manager.py

Lines changed: 1 addition & 3 deletions
@@ -1,4 +1,3 @@
-import os
 import time
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple

@@ -67,8 +66,7 @@ def _run_compilation(self, name: str, fn: Callable, *args,
         logger.info("Compilation finished in %.2f [secs].", end - start)

     def capture_model(self) -> None:
-        if os.getenv("SKIP_JAX_PRECOMPILE",
-                     False) or self.runner.model_config.enforce_eager:
+        if self.runner.model_config.enforce_eager:
             return
         logger.info("Precompile all the subgraphs with possible input shapes.")

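With the environment variable gone, the precompilation gate keys off the engine configuration alone. A rough sketch of driving it through EngineArgs, mirroring the pattern the updated multi-modal test uses (the values here are illustrative, not prescribed by this commit):

    from dataclasses import asdict

    from vllm import LLM, EngineArgs

    # enforce_eager propagates to model_config.enforce_eager, so the
    # compilation manager's capture_model() returns early and the long
    # precompilation warmup is skipped -- no separate env var needed.
    engine_args = EngineArgs(
        model="meta-llama/Llama-3.2-1B-Instruct",
        max_model_len=1024,
        enforce_eager=True,
    )
    llm = LLM(**asdict(engine_args))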