@@ -174,11 +174,11 @@ def test_contexted_kv_attention(
     block_table = values[:BS * max_block_per_request].view(BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
-    b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
-    )
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+        torch.int32
+    )
     for i in range(BS):
         for j in range(query_lens[i]):
@@ -417,11 +417,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     block_table = values[:BS * max_block_per_request].view(BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.int32)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.int32)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)
     max_input_len = MAX_SEQ_LEN
     # copy kv to cache
-    b_seq_start_loc = torch.cumsum(
-        torch.tensor([0] + seq_lens[:-1], dtype=torch.int32), dim=0
-    )
+    b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1]), dim=0).to(
+        torch.int32
+    )
     for i in range(BS):
         for j in range(query_lens[i]):
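
Both hunks apply the same dtype-handling change in two tests: rather than building the prefix-sum tensors directly in int32 and running cumsum on them, the cumulative sum is now computed in the default int64 dtype and cast to int32 afterwards. A minimal standalone sketch of the two forms (the query_lens values below are made up for illustration and are not from the tests) shows they produce the same int32 result for small inputs:

import torch

query_lens = [3, 5, 2]  # hypothetical per-request query lengths

# Old form: construct the tensor as int32, so the cumsum is accumulated in int32.
old = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0)

# New form: accumulate the cumsum in the default int64 dtype, then cast to int32.
new = torch.cumsum(torch.tensor([0] + query_lens), dim=0).to(torch.int32)

assert old.dtype == new.dtype == torch.int32
assert torch.equal(old, new)  # [0, 3, 8, 10] in both cases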