@@ -1434,13 +1434,11 @@ def get_mrope_input_positions(
14341434 vision_start_token_id = hf_config .vision_start_token_id
14351435 spatial_merge_size = hf_config .vision_config .spatial_merge_size
14361436
1437- input_tokens_tensor = torch .tensor (input_tokens )
1438- vision_start_indices = torch .argwhere (
1439- input_tokens_tensor == vision_start_token_id
1440- ).squeeze (1 )
1441- vision_tokens = input_tokens_tensor [vision_start_indices + 1 ]
1442- image_nums = (vision_tokens == image_token_id ).sum ()
1443- video_nums = (vision_tokens == video_token_id ).sum ()
1437+ input_tokens_array = np .array (input_tokens )
1438+ vision_start_mask = input_tokens_array == vision_start_token_id
1439+ vision_tokens = input_tokens_array [vision_start_mask .nonzero ()[0 ] + 1 ]
1440+ image_nums = np .count_nonzero (vision_tokens == image_token_id )
1441+ video_nums = np .count_nonzero (vision_tokens == video_token_id )
14441442 llm_pos_ids_list : list = []
14451443
14461444 st = 0
@@ -1484,43 +1482,24 @@ def get_mrope_input_positions(
14841482
14851483 st_idx = llm_pos_ids_list [- 1 ].max () + 1 if len (llm_pos_ids_list ) > 0 else 0
14861484 llm_pos_ids_list .append (
1487- torch . arange (text_len ). view ( 1 , - 1 ). expand (3 , - 1 ) + st_idx
1485+ np . broadcast_to ( np . arange (text_len ), (3 , text_len ) ) + st_idx
14881486 )
14891487
1490- t_index = (
1491- torch .arange (llm_grid_t )
1492- .view (- 1 , 1 )
1493- .expand (- 1 , llm_grid_h * llm_grid_w )
1494- .flatten ()
1495- )
1496- h_index = (
1497- torch .arange (llm_grid_h )
1498- .view (1 , - 1 , 1 )
1499- .expand (llm_grid_t , - 1 , llm_grid_w )
1500- .flatten ()
1501- )
1502- w_index = (
1503- torch .arange (llm_grid_w )
1504- .view (1 , 1 , - 1 )
1505- .expand (llm_grid_t , llm_grid_h , - 1 )
1506- .flatten ()
1507- )
1508- llm_pos_ids_list .append (
1509- torch .stack ([t_index , h_index , w_index ]) + text_len + st_idx
1510- )
1488+ grid_indices = np .indices ((llm_grid_t , llm_grid_h , llm_grid_w ))
1489+ llm_pos_ids_list .append (grid_indices .reshape (3 , - 1 ) + text_len + st_idx )
15111490 st = ed + llm_grid_t * llm_grid_h * llm_grid_w
15121491
15131492 if st < len (input_tokens ):
15141493 st_idx = llm_pos_ids_list [- 1 ].max () + 1 if len (llm_pos_ids_list ) > 0 else 0
15151494 text_len = len (input_tokens ) - st
15161495 llm_pos_ids_list .append (
1517- torch . arange (text_len ). view ( 1 , - 1 ). expand (3 , - 1 ) + st_idx
1496+ np . broadcast_to ( np . arange (text_len ), (3 , text_len ) ) + st_idx
15181497 )
15191498
1520- llm_positions = torch . cat (llm_pos_ids_list , dim = 1 ).reshape (3 , - 1 )
1499+ llm_positions = np . concatenate (llm_pos_ids_list , axis = 1 ).reshape (3 , - 1 )
15211500 mrope_position_delta = (llm_positions .max () + 1 - len (input_tokens )).item ()
15221501 llm_positions = llm_positions [:, context_len :seq_len ]
1523- return llm_positions , mrope_position_delta
1502+ return torch . from_numpy ( llm_positions ) , mrope_position_delta
15241503
def get_language_model(self) -> torch.nn.Module:
    """Return the wrapped text-decoding backbone of this multimodal model."""
    backbone = self.language_model
    return backbone
0 commit comments