
Commit 70ebcc9

[Bug fix] Pathways HBM calculation with DP enabled + Async scheduler precompilation (#1072)
Signed-off-by: wenxindongwork <wenxindong@google.com>
1 parent 825acd1 commit 70ebcc9

3 files changed: +36, −30 lines

tests/test_utils.py

Lines changed: 24 additions & 15 deletions
```diff
@@ -75,28 +75,36 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
     mock_device2 = MagicMock()
     devices = [mock_device1, mock_device2]
 
-    # Create mock arrays with sharding
+    # Create mock device buffers
+    mock_buffer1_dev1 = MagicMock()
+    mock_buffer1_dev1.device = mock_device1
+    mock_buffer1_dev1.nbytes = 2000  # 2000 bytes on device1
+
+    mock_buffer1_dev2 = MagicMock()
+    mock_buffer1_dev2.device = mock_device2
+    mock_buffer1_dev2.nbytes = 2000  # 2000 bytes on device2
+
+    mock_buffer2_dev1 = MagicMock()
+    mock_buffer2_dev1.device = mock_device1
+    mock_buffer2_dev1.nbytes = 1000  # 1000 bytes on device1
+
+    # Create mock arrays with device buffers
     mock_array1 = MagicMock()
-    mock_array1.dtype.itemsize = 4  # float32
-    mock_array1.size = 1000  # 1000 elements
-    mock_array1.sharding.device_set = {mock_device1, mock_device2
-                                       }  # Sharded across 2 devices
+    mock_array1.device_buffers = [mock_buffer1_dev1, mock_buffer1_dev2]
 
     mock_array2 = MagicMock()
-    mock_array2.dtype.itemsize = 2  # float16
-    mock_array2.size = 500  # 500 elements
-    mock_array2.sharding.device_set = {mock_device1}  # Only on device1
+    mock_array2.device_buffers = [mock_buffer2_dev1]
 
     mock_live_arrays.return_value = [mock_array1, mock_array2]
 
     usage = hbm_usage_bytes(devices)
 
     # Expected calculations:
-    # Array1: 4 bytes * 1000 elements / 2 devices = 2000 bytes per device
-    # Array2: 2 bytes * 500 elements / 1 device = 1000 bytes on device1 only
-    # Device1: 2000 + 1000 = 3000 bytes
-    # Device2: 2000 + 0 = 2000 bytes
-    # hbm_limit = 33550237184 (hardcoded in the function)
+    # Array1: 2000 bytes on device1, 2000 bytes on device2
+    # Array2: 1000 bytes on device1
+    # Device1 total: 2000 + 1000 = 3000 bytes
+    # Device2 total: 2000 + 0 = 2000 bytes
+    # hbm_limit = 95 * GBYTES for TPU v5p
     expected_usage = [(3000, 95 * GBYTES), (2000, 95 * GBYTES)]
     assert usage == expected_usage
 
@@ -127,7 +135,7 @@ def test_hbm_usage_gb_pathways_disabled():
 @patch("jax.devices")
 def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
     """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True but no live arrays."""
-    # Mock TPU v5e devices
+    # Mock TPU v6e devices
     mock_jax_device = MagicMock()
     mock_jax_device.device_kind = "TPU v6e"
     mock_devices.return_value = [mock_jax_device]
@@ -141,7 +149,8 @@ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
 
     usage = hbm_usage_bytes(devices)
 
-    # No arrays means no memory usage
+    # No arrays means no memory usage, defaultdict returns 0 for missing keys
+    # HBM limit for TPU v6e is 32 GB
     expected_usage = [(0, 32 * GBYTES), (0, 32 * GBYTES)]
     assert usage == expected_usage
```
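The per-device totals asserted in the first test fall directly out of the buffer-based accounting this commit introduces. A minimal, self-contained sketch of that arithmetic with the same mocked buffer layout as the test (the real `hbm_usage_bytes` is not imported here, so this only illustrates the summation):

```python
from collections import defaultdict
from unittest.mock import MagicMock

# Two mock devices, mirroring the test above.
device1, device2 = MagicMock(), MagicMock()

# Mock arrays expose their per-device buffers directly.
array1 = MagicMock()
array1.device_buffers = [
    MagicMock(device=device1, nbytes=2000),
    MagicMock(device=device2, nbytes=2000),
]
array2 = MagicMock()
array2.device_buffers = [MagicMock(device=device1, nbytes=1000)]

# Sum actual buffer bytes per device, as the fixed implementation does.
hbm_used = defaultdict(int)
for array in (array1, array2):
    for buffer in array.device_buffers:
        hbm_used[buffer.device] += buffer.nbytes

print(hbm_used[device1], hbm_used[device2])  # 3000 2000
```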

tpu_inference/runner/compilation_manager.py

Lines changed: 10 additions & 9 deletions
```diff
@@ -202,20 +202,21 @@ def _precompile_substitute_placeholder_token(self) -> None:
         """
 
         for num_tokens in self.runner.num_tokens_paddings:
-            padded_token_in_tpu_cur_input_indices = np.zeros((num_tokens, ),
-                                                             dtype=np.int32)
-            padded_token_in_tpu_pre_next_tokens_indices = np.zeros(
-                (num_tokens, ), dtype=jnp.int32)
-            (padded_token_in_tpu_cur_input_indices,
-             padded_token_in_tpu_pre_next_tokens_indices) = device_array(
-                 self.runner.mesh,
-                 (padded_token_in_tpu_cur_input_indices,
-                  padded_token_in_tpu_pre_next_tokens_indices))
             dp_sharding = NamedSharding(
                 self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, )
             ) if self.runner.vllm_config.sharding_config.total_dp_size > 1 else None
 
             for num_reqs in self.runner.num_reqs_paddings:
+                padded_token_in_tpu_cur_input_indices = np.zeros(
+                    (num_tokens, ), dtype=np.int32)
+                padded_token_in_tpu_pre_next_tokens_indices = np.zeros(
+                    (num_tokens, ), dtype=jnp.int32)
+                (padded_token_in_tpu_cur_input_indices,
+                 padded_token_in_tpu_pre_next_tokens_indices) = device_array(
+                     self.runner.mesh,
+                     (padded_token_in_tpu_cur_input_indices,
+                      padded_token_in_tpu_pre_next_tokens_indices))
+
                 input_ids = self._create_dummy_tensor((num_tokens, ),
                                                       jnp.int32, dp_sharding)
                 # Need align to the sampling output
```
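The fix rebuilds and re-places the placeholder index arrays inside the `num_reqs` loop, so each `(num_tokens, num_reqs)` precompilation variant starts from freshly device-placed inputs instead of reusing arrays created once per `num_tokens`. A rough sketch of the resulting loop shape, with assumed padding ladders and plain NumPy standing in for the mesh placement that `device_array` performs:

```python
import numpy as np

# Assumed padding ladders, for illustration only.
num_tokens_paddings = [16, 32, 64]
num_reqs_paddings = [4, 8]

for num_tokens in num_tokens_paddings:
    # dp_sharding would be computed once per num_tokens here.
    for num_reqs in num_reqs_paddings:
        # Rebuilt on every (num_tokens, num_reqs) combination; in the real
        # code these are then moved onto the mesh via device_array(...).
        cur_input_indices = np.zeros((num_tokens, ), dtype=np.int32)
        pre_next_tokens_indices = np.zeros((num_tokens, ), dtype=np.int32)
        print(num_tokens, num_reqs, cur_input_indices.shape,
              pre_next_tokens_indices.shape)
```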

tpu_inference/utils.py

Lines changed: 2 additions & 6 deletions
```diff
@@ -132,12 +132,8 @@ def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
     hbm_used = defaultdict(int)
     hbm_limit = get_device_hbm_limit()
     for array in live_arrays:
-        assert hasattr(array, 'sharding') and hasattr(
-            array.sharding, 'device_set'
-        ), "This function must not be called within jax tracer (e.g. jit, vmap, grad)"
-        for device in array.sharding.device_set:
-            hbm_used[device] += array.dtype.itemsize * array.size // len(
-                array.sharding.device_set)
+        for buffer in array.device_buffers:
+            hbm_used[buffer.device] += buffer.nbytes
     return [(hbm_used[device], hbm_limit) for device in devices]
```
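The removed logic split each array's bytes evenly across `sharding.device_set`, which undercounts arrays that are replicated rather than sharded; under data parallelism a replicated array occupies its full size on every device that holds it, which is a plausible failure mode behind this fix. A sketch of the discrepancy using mocks (not real JAX arrays; `device_buffers` stands in for the runtime's per-device buffer list):

```python
from collections import defaultdict
from unittest.mock import MagicMock

# Assumed DP scenario: one array fully replicated on two devices, so it
# occupies its full 4000 bytes on EACH device.
dev1, dev2 = MagicMock(), MagicMock()
array = MagicMock()
array.dtype.itemsize = 4
array.size = 1000  # 4 * 1000 = 4000 bytes per replica
array.sharding.device_set = {dev1, dev2}
array.device_buffers = [
    MagicMock(device=dev1, nbytes=4000),
    MagicMock(device=dev2, nbytes=4000),
]

# Old accounting: even split across the device set -> 2000 per device,
# half the real footprint of a replicated array.
old_per_device = array.dtype.itemsize * array.size // len(
    array.sharding.device_set)
print(old_per_device)  # 2000

# New accounting: sum actual buffer bytes per device -> 4000 per device.
hbm_used = defaultdict(int)
for buffer in array.device_buffers:
    hbm_used[buffer.device] += buffer.nbytes
print(hbm_used[dev1], hbm_used[dev2])  # 4000 4000
```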