@@ -75,25 +75,34 @@ def test_hbm_usage_bytes_pathways_enabled(mock_devices, mock_live_arrays):
7575 mock_device2 = MagicMock ()
7676 devices = [mock_device1 , mock_device2 ]
7777
78- # Create mock device buffers
79- mock_buffer1_dev1 = MagicMock ()
80- mock_buffer1_dev1 .device = mock_device1
81- mock_buffer1_dev1 .nbytes = 2000 # 2000 bytes on device1
78+ # Create mock addressable shards with data property
79+ mock_data1_dev1 = MagicMock ()
80+ mock_data1_dev1 .device = mock_device1
81+ mock_data1_dev1 .nbytes = 2000 # 2000 bytes on device1
8282
83- mock_buffer1_dev2 = MagicMock ()
84- mock_buffer1_dev2 .device = mock_device2
85- mock_buffer1_dev2 .nbytes = 2000 # 2000 bytes on device2
83+ mock_data1_dev2 = MagicMock ()
84+ mock_data1_dev2 .device = mock_device2
85+ mock_data1_dev2 .nbytes = 2000 # 2000 bytes on device2
8686
87- mock_buffer2_dev1 = MagicMock ()
88- mock_buffer2_dev1 .device = mock_device1
89- mock_buffer2_dev1 .nbytes = 1000 # 1000 bytes on device1
87+ mock_data2_dev1 = MagicMock ()
88+ mock_data2_dev1 .device = mock_device1
89+ mock_data2_dev1 .nbytes = 1000 # 1000 bytes on device1
9090
91- # Create mock arrays with device buffers
91+ mock_shard1_dev1 = MagicMock ()
92+ mock_shard1_dev1 .data = mock_data1_dev1
93+
94+ mock_shard1_dev2 = MagicMock ()
95+ mock_shard1_dev2 .data = mock_data1_dev2
96+
97+ mock_shard2_dev1 = MagicMock ()
98+ mock_shard2_dev1 .data = mock_data2_dev1
99+
100+ # Create mock arrays with addressable_shards
92101 mock_array1 = MagicMock ()
93- mock_array1 .device_buffers = [mock_buffer1_dev1 , mock_buffer1_dev2 ]
102+ mock_array1 .addressable_shards = [mock_shard1_dev1 , mock_shard1_dev2 ]
94103
95104 mock_array2 = MagicMock ()
96- mock_array2 .device_buffers = [mock_buffer2_dev1 ]
105+ mock_array2 .addressable_shards = [mock_shard2_dev1 ]
97106
98107 mock_live_arrays .return_value = [mock_array1 , mock_array2 ]
99108
0 commit comments