@@ -1432,13 +1432,11 @@ def get_mrope_input_positions(
1432 1432         vision_start_token_id = hf_config.vision_start_token_id
1433 1433         spatial_merge_size = hf_config.vision_config.spatial_merge_size
1434 1434
1435      -       input_tokens_tensor = torch.tensor(input_tokens)
1436      -       vision_start_indices = torch.argwhere(
1437      -           input_tokens_tensor == vision_start_token_id
1438      -       ).squeeze(1)
1439      -       vision_tokens = input_tokens_tensor[vision_start_indices + 1]
1440      -       image_nums = (vision_tokens == image_token_id).sum()
1441      -       video_nums = (vision_tokens == video_token_id).sum()
     1435 +       input_tokens_array = np.array(input_tokens)
     1436 +       vision_start_mask = input_tokens_array == vision_start_token_id
     1437 +       vision_tokens = input_tokens_array[vision_start_mask.nonzero()[0] + 1]
     1438 +       image_nums = np.count_nonzero(vision_tokens == image_token_id)
     1439 +       video_nums = np.count_nonzero(vision_tokens == video_token_id)
1442 1440         llm_pos_ids_list: list = []
1443 1441
1444 1442         st = 0
@@ -1474,43 +1472,23 @@ def get_mrope_input_positions(
1474 1472
1475 1473             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1476 1474             llm_pos_ids_list.append(
1477      -               torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
     1475 +               np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
1478 1476             )
1479 1477
1480      -           t_index = (
1481      -               torch.arange(llm_grid_t)
1482      -               .view(-1, 1)
1483      -               .expand(-1, llm_grid_h * llm_grid_w)
1484      -               .flatten()
1485      -           )
1486      -           h_index = (
1487      -               torch.arange(llm_grid_h)
1488      -               .view(1, -1, 1)
1489      -               .expand(llm_grid_t, -1, llm_grid_w)
1490      -               .flatten()
1491      -           )
1492      -           w_index = (
1493      -               torch.arange(llm_grid_w)
1494      -               .view(1, 1, -1)
1495      -               .expand(llm_grid_t, llm_grid_h, -1)
1496      -               .flatten()
1497      -           )
1498      -           llm_pos_ids_list.append(
1499      -               torch.stack([t_index, h_index, w_index]) + text_len + st_idx
1500      -           )
     1478 +           grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
     1479 +           llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
1501 1480             st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1502 1481
1503 1482         if st < len(input_tokens):
1504 1483             st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1505 1484             text_len = len(input_tokens) - st
1506 1485             llm_pos_ids_list.append(
1507      -               torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
     1486 +               np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
1508 1487             )
1509 1488
1510      -       llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
     1489 +       llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
1511 1490         mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
1512      -
1513      -       return llm_positions, mrope_position_delta
     1491 +       return torch.from_numpy(llm_positions), mrope_position_delta
15141492
1515 1493     def get_language_model(self) -> torch.nn.Module:
1516 1494         return self.language_model
0 commit comments