@@ -75,28 +75,36 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
7575 mock_device2 = MagicMock ()
7676 devices = [mock_device1 , mock_device2 ]
7777
78- # Create mock arrays with sharding
78+ # Create mock device buffers
79+ mock_buffer1_dev1 = MagicMock ()
80+ mock_buffer1_dev1 .device = mock_device1
81+ mock_buffer1_dev1 .nbytes = 2000 # 2000 bytes on device1
82+
83+ mock_buffer1_dev2 = MagicMock ()
84+ mock_buffer1_dev2 .device = mock_device2
85+ mock_buffer1_dev2 .nbytes = 2000 # 2000 bytes on device2
86+
87+ mock_buffer2_dev1 = MagicMock ()
88+ mock_buffer2_dev1 .device = mock_device1
89+ mock_buffer2_dev1 .nbytes = 1000 # 1000 bytes on device1
90+
91+ # Create mock arrays with device buffers
7992 mock_array1 = MagicMock ()
80- mock_array1 .dtype .itemsize = 4 # float32
81- mock_array1 .size = 1000 # 1000 elements
82- mock_array1 .sharding .device_set = {mock_device1 , mock_device2
83- } # Sharded across 2 devices
93+ mock_array1 .device_buffers = [mock_buffer1_dev1 , mock_buffer1_dev2 ]
8494
8595 mock_array2 = MagicMock ()
86- mock_array2 .dtype .itemsize = 2 # float16
87- mock_array2 .size = 500 # 500 elements
88- mock_array2 .sharding .device_set = {mock_device1 } # Only on device1
96+ mock_array2 .device_buffers = [mock_buffer2_dev1 ]
8997
9098 mock_live_arrays .return_value = [mock_array1 , mock_array2 ]
9199
92100 usage = hbm_usage_bytes (devices )
93101
94102 # Expected calculations:
95- # Array1: 4 bytes * 1000 elements / 2 devices = 2000 bytes per device
96- # Array2: 2 bytes * 500 elements / 1 device = 1000 bytes on device1 only
97- # Device1: 2000 + 1000 = 3000 bytes
98- # Device2: 2000 + 0 = 2000 bytes
99- # hbm_limit = 33550237184 (hardcoded in the function)
103+ # Array1: 2000 bytes on device1, 2000 bytes on device2
104+ # Array2: 1000 bytes on device1
105+ # Device1 total : 2000 + 1000 = 3000 bytes
106+ # Device2 total : 2000 + 0 = 2000 bytes
107+ # hbm_limit = 95 * GBYTES for TPU v5p
100108 expected_usage = [(3000 , 95 * GBYTES ), (2000 , 95 * GBYTES )]
101109 assert usage == expected_usage
102110
@@ -127,7 +135,7 @@ def test_hbm_usage_gb_pathways_disabled():
127135@patch ("jax.devices" )
128136def test_hbm_usage_bytes_pathways_no_arrays (mock_devices , mock_live_arrays ):
129137 """Tests hbm_usage_bytes when VLLM_TPU_USING_PATHWAYS is True but no live arrays."""
130- # Mock TPU v5e devices
138+ # Mock TPU v6e devices
131139 mock_jax_device = MagicMock ()
132140 mock_jax_device .device_kind = "TPU v6e"
133141 mock_devices .return_value = [mock_jax_device ]
@@ -141,7 +149,8 @@ def test_hbm_usage_bytes_pathways_no_arrays(mock_devices, mock_live_arrays):
141149
142150 usage = hbm_usage_bytes (devices )
143151
144- # No arrays means no memory usage
152+ # No arrays means no memory usage, defaultdict returns 0 for missing keys
153+ # HBM limit for TPU v6e is 32 GB
145154 expected_usage = [(0 , 32 * GBYTES ), (0 , 32 * GBYTES )]
146155 assert usage == expected_usage
147156
0 commit comments