Skip to content

Commit 41dd2ad

Browse files
committed
vectorize sequence for init, derive vectorized state
1 parent c8442e0 commit 41dd2ad

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

lib/bumblebee/text/generation.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ defmodule Bumblebee.Text.Generation do
592 592
finished_length = Nx.select(padded_batch_item?, 1, 0)
593 593

594 594
context = %{
595-
sequences: sequences,
595+
sequence: Nx.vectorize(sequences, :batch),
596 596
input_length: length,
597 597
length: length
598 598
}

test/bumblebee/text/generation_test.exs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,10 @@ defmodule Bumblebee.Text.GenerationTest do
233 233

234 234
@impl Bumblebee.LogitsProcessor
235 235
def init(logits_processor, context) do
236-
batch_size = Nx.axis_size(context.sequences, 0)
236+
initial_enforced_token_id = Nx.tensor([logits_processor.initial_enforced_token_id])
237 237

238-
initial_enforced_batch_token_id =
239-
Nx.broadcast(logits_processor.initial_enforced_token_id, {batch_size, 1})
238+
[initial_enforced_batch_token_id, _sequence] =
239+
Nx.broadcast_vectors([initial_enforced_token_id, context.sequence])
240 240

241 241
%{
242 242
next_enforced_token_id: initial_enforced_batch_token_id
@@ -245,13 +245,11 @@ defmodule Bumblebee.Text.GenerationTest do
245 245

246 246
@impl Bumblebee.LogitsProcessor
247 247
def process(_logits_processor, state, logits, _context) do
248-
next_enforced_token_id = Nx.vectorize(state.next_enforced_token_id, :batch)
248+
next_enforced_token_id = state.next_enforced_token_id
249 249

250 250
logits = enforce_token(logits, next_enforced_token_id)
251 251

252-
next_enforced_token_id =
253-
Nx.add(next_enforced_token_id, 1)
254-
|> Nx.devectorize(keep_names: false)
252+
next_enforced_token_id = Nx.add(next_enforced_token_id, 1)
255 253

256 254
state = put_in(state.next_enforced_token_id, next_enforced_token_id)
257 255

0 commit comments

Comments
 (0)