diff --git a/.buildkite/models/google_gemma-3-27b-it.yml b/.buildkite/models/google_gemma-3-27b-it.yml
index 46089c5ad..531d29c53 100644
--- a/.buildkite/models/google_gemma-3-27b-it.yml
+++ b/.buildkite/models/google_gemma-3-27b-it.yml
@@ -8,7 +8,7 @@ steps:
     commands:
     - |
       .buildkite/scripts/run_in_docker.sh \
-        bash -c 'SKIP_JAX_PRECOMPILE=1 VLLM_XLA_CHECK_RECOMPILATION=0 python3 /workspace/tpu_inference/examples/offline_inference.py --model=google/gemma-3-27b-it --tensor_parallel_size=8 --task=generate --max_model_len=1024 --max_num_seqs=1'
+        bash -c 'VLLM_XLA_CHECK_RECOMPILATION=0 python3 /workspace/tpu_inference/examples/offline_inference.py --model=google/gemma-3-27b-it --tensor_parallel_size=8 --task=generate --max_model_len=1024 --max_num_seqs=1'
   - label: "Record unit test result for google/gemma-3-27b-it"
     key: "record_google_gemma-3-27b-it_UnitTest"
     depends_on: "google_gemma-3-27b-it_UnitTest"
diff --git a/examples/disagg/run_disagg_multi_host.sh b/examples/disagg/run_disagg_multi_host.sh
index 62dac10fc..41b4ca059 100755
--- a/examples/disagg/run_disagg_multi_host.sh
+++ b/examples/disagg/run_disagg_multi_host.sh
@@ -63,7 +63,6 @@ for ((i=0; i /root/logs/prefill.txt 2>&1 &"
   set +x
@@ -137,7 +137,6 @@ for ((i=0; i /root/logs/decode.txt 2>&1 &"
   set +x
diff --git a/examples/disagg/run_disagg_single_host.sh b/examples/disagg/run_disagg_single_host.sh
index efd8fef97..7d69321ff 100755
--- a/examples/disagg/run_disagg_single_host.sh
+++ b/examples/disagg/run_disagg_single_host.sh
@@ -45,13 +45,13 @@ for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
   \
   TPU_KV_TRANSFER_PORT=$KV_PORT \
   TPU_SIDE_CHANNEL_PORT=$SIDE_PORT \
-  SKIP_JAX_PRECOMPILE=1 \
   \
   vllm serve $MODEL \
   --port $PORT \
   --gpu-memory-utilization 0.2 \
   --tensor-parallel-size $PREFILLER_TP_SIZE \
   --kv-transfer-config "{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_producer\"}" \
+  --enforce-eager \
   > $HOME/logs/prefill_$i.txt 2>&1 &

   PREFILL_HOSTS+=("localhost")
@@ -72,13 +72,13 @@ for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
   \
   TPU_KV_TRANSFER_PORT=$KV_PORT \
   TPU_SIDE_CHANNEL_PORT=$SIDE_PORT \
-  SKIP_JAX_PRECOMPILE=1 \
   \
   vllm serve $MODEL \
   --port $PORT \
   --gpu-memory-utilization 0.2 \
   --tensor-parallel-size $DECODER_TP_SIZE \
   --kv-transfer-config "{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_consumer\"}" \
+  --enforce-eager \
   > $HOME/logs/decode_$i.txt 2>&1 &

   DECODE_HOSTS+=("localhost")
diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index 293560767..7fc5e90ea 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import os
-
 import vllm.envs as envs
 from vllm import LLM, EngineArgs
 from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -17,6 +15,9 @@ def create_parser():
     parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
     parser.set_defaults(max_model_len=1024)

+    # Skip long warmup for local simple test.
+    parser.set_defaults(enforce_eager=True)
+
     # Add sampling params
     sampling_group = parser.add_argument_group("Sampling parameters")
     sampling_group.add_argument("--max-tokens", type=int)
@@ -103,9 +104,6 @@ def main(args: dict):


 if __name__ == "__main__":
-    # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-
     parser = create_parser()
     args: dict = vars(parser.parse_args())

diff --git a/examples/offline_lora_inference.py b/examples/offline_lora_inference.py
index 386c74e5e..d8f14b2bf 100644
--- a/examples/offline_lora_inference.py
+++ b/examples/offline_lora_inference.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import os
 import time

 import vllm.envs as envs
@@ -20,6 +19,9 @@ def create_parser():
     parser.set_defaults(enable_lora=True)
     parser.set_defaults(max_lora_rank=8)

+    # Skip long warmup for local simple test.
+    parser.set_defaults(enforce_eager=True)
+
     # Add sampling params
     sampling_group = parser.add_argument_group("Sampling parameters")
     sampling_group.add_argument("--max-tokens", type=int, default=16)
@@ -76,9 +78,6 @@ def main(args: dict):


 if __name__ == "__main__":
-    # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-
     parser = create_parser()
     args: dict = vars(parser.parse_args())

diff --git a/tests/e2e/benchmarking/mm_bench.sh b/tests/e2e/benchmarking/mm_bench.sh
index 063c78625..27aab3446 100755
--- a/tests/e2e/benchmarking/mm_bench.sh
+++ b/tests/e2e/benchmarking/mm_bench.sh
@@ -91,7 +91,7 @@ checkThroughputAndRouge() {
 }

 echo "Spinning up the vLLM server..."
-(SKIP_JAX_PRECOMPILE=1 VLLM_XLA_CHECK_RECOMPILATION=0 vllm serve "$model_name" --max-model-len "$max_model_len" --max-num-seqs "$max_num_seqs" --disable-log-requests --max-num-batched-tokens "$max_batched_tokens" 2>&1 | tee -a "$LOG_FILE") &
+(VLLM_XLA_CHECK_RECOMPILATION=0 vllm serve "$model_name" --max-model-len "$max_model_len" --max-num-seqs "$max_num_seqs" --disable-log-requests --max-num-batched-tokens "$max_batched_tokens" --enforce-eager 2>&1 | tee -a "$LOG_FILE") &


 # Run a busy loop to block until the server is ready to receive requests
diff --git a/tests/e2e/test_multi_modal_inference.py b/tests/e2e/test_multi_modal_inference.py
index c1d2bda77..60ca47b19 100644
--- a/tests/e2e/test_multi_modal_inference.py
+++ b/tests/e2e/test_multi_modal_inference.py
@@ -24,7 +24,6 @@ def test_multi_modal_inference(monkeypatch):
     """
     Runs multi-modal inference and verifies the output.
     """
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'  # Skip warmup to save time.
     os.environ[
         'VLLM_XLA_CHECK_RECOMPILATION'] = '0'  # Allow compilation during execution.

@@ -65,6 +64,7 @@ def test_multi_modal_inference(monkeypatch):
             "fps": 1,
         },
         limit_mm_per_prompt={modality: 1},
+        enforce_eager=True,  # Skip warmup to save time.
     )
     engine_args = asdict(engine_args)
     llm = LLM(**engine_args)
diff --git a/tpu_inference/runner/compilation_manager.py b/tpu_inference/runner/compilation_manager.py
index c50c4bc86..5ded68efb 100644
--- a/tpu_inference/runner/compilation_manager.py
+++ b/tpu_inference/runner/compilation_manager.py
@@ -1,4 +1,3 @@
-import os
 import time
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple

@@ -67,8 +66,7 @@ def _run_compilation(self, name: str, fn: Callable, *args,
         logger.info("Compilation finished in %.2f [secs].", end - start)

     def capture_model(self) -> None:
-        if os.getenv("SKIP_JAX_PRECOMPILE",
-                     False) or self.runner.model_config.enforce_eager:
+        if self.runner.model_config.enforce_eager:
             return
         logger.info("Precompile all the subgraphs with possible input shapes.")