Skip to content

Commit f817453

Browse files
Update Pathways HBM calculation (#1097)
Signed-off-by: wenxindongwork <wenxindong@google.com>
1 parent 9a5d703 commit f817453

File tree

2 files changed

+24
-15
lines changed

2 files changed

+24
-15
lines changed

tests/test_utils.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,25 +75,34 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
7575
mock_device2 = MagicMock()
7676
devices = [mock_device1, mock_device2]
7777

78-
# Create mock device buffers
79-
mock_buffer1_dev1 = MagicMock()
80-
mock_buffer1_dev1.device = mock_device1
81-
mock_buffer1_dev1.nbytes = 2000 # 2000 bytes on device1
78+
# Create mock addressable shards with data property
79+
mock_data1_dev1 = MagicMock()
80+
mock_data1_dev1.device = mock_device1
81+
mock_data1_dev1.nbytes = 2000 # 2000 bytes on device1
8282

83-
mock_buffer1_dev2 = MagicMock()
84-
mock_buffer1_dev2.device = mock_device2
85-
mock_buffer1_dev2.nbytes = 2000 # 2000 bytes on device2
83+
mock_data1_dev2 = MagicMock()
84+
mock_data1_dev2.device = mock_device2
85+
mock_data1_dev2.nbytes = 2000 # 2000 bytes on device2
8686

87-
mock_buffer2_dev1 = MagicMock()
88-
mock_buffer2_dev1.device = mock_device1
89-
mock_buffer2_dev1.nbytes = 1000 # 1000 bytes on device1
87+
mock_data2_dev1 = MagicMock()
88+
mock_data2_dev1.device = mock_device1
89+
mock_data2_dev1.nbytes = 1000 # 1000 bytes on device1
9090

91-
# Create mock arrays with device buffers
91+
mock_shard1_dev1 = MagicMock()
92+
mock_shard1_dev1.data = mock_data1_dev1
93+
94+
mock_shard1_dev2 = MagicMock()
95+
mock_shard1_dev2.data = mock_data1_dev2
96+
97+
mock_shard2_dev1 = MagicMock()
98+
mock_shard2_dev1.data = mock_data2_dev1
99+
100+
# Create mock arrays with addressable_shards
92101
mock_array1 = MagicMock()
93-
mock_array1.device_buffers = [mock_buffer1_dev1, mock_buffer1_dev2]
102+
mock_array1.addressable_shards = [mock_shard1_dev1, mock_shard1_dev2]
94103

95104
mock_array2 = MagicMock()
96-
mock_array2.device_buffers = [mock_buffer2_dev1]
105+
mock_array2.addressable_shards = [mock_shard2_dev1]
97106

98107
mock_live_arrays.return_value = [mock_array1, mock_array2]
99108

tpu_inference/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
132132
hbm_used = defaultdict(int)
133133
hbm_limit = get_device_hbm_limit()
134134
for array in live_arrays:
135-
for buffer in array.device_buffers:
136-
hbm_used[buffer.device] += buffer.nbytes
135+
for buffer in array.addressable_shards:
136+
hbm_used[buffer.data.device] += buffer.data.nbytes
137137
return [(hbm_used[device], hbm_limit) for device in devices]
138138

139139

0 commit comments

Comments
 (0)