@@ -233,10 +233,10 @@ defmodule Bumblebee.Text.GenerationTest do
233233
234234 @impl Bumblebee.LogitsProcessor
235235 def init(logits_processor, context) do
236- batch_size = Nx.axis_size(context.sequences, 0)
236+ initial_enforced_token_id = Nx.tensor([logits_processor.initial_enforced_token_id])
237237
238- initial_enforced_batch_token_id =
239- Nx.broadcast(logits_processor.initial_enforced_token_id, {batch_size, 1})
238+ [initial_enforced_batch_token_id, _sequence] =
239+ Nx.broadcast_vectors([initial_enforced_token_id, context.sequence])
240240
241241 %{
242242 next_enforced_token_id: initial_enforced_batch_token_id
@@ -245,13 +245,11 @@ defmodule Bumblebee.Text.GenerationTest do
245245
246246 @impl Bumblebee.LogitsProcessor
247247 def process(_logits_processor, state, logits, _context) do
248- next_enforced_token_id = Nx.vectorize(state.next_enforced_token_id, :batch)
248+ next_enforced_token_id = state.next_enforced_token_id
249249
250250 logits = enforce_token(logits, next_enforced_token_id)
251251
252- next_enforced_token_id =
253- Nx.add(next_enforced_token_id, 1)
254- |> Nx.devectorize(keep_names: false)
252+ next_enforced_token_id = Nx.add(next_enforced_token_id, 1)
255253
256254 state = put_in(state.next_enforced_token_id, next_enforced_token_id)
257255
0 commit comments