Skip to content

Commit 572b748

Browse files
committed
introduced types for init_context and process_context
1 parent 70d7f65 commit 572b748

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/bumblebee/text/generation_test.exs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,19 +220,19 @@ defmodule Bumblebee.Text.GenerationTest do
220220
end
221221

222222
@impl Bumblebee.LogitsProcessor
223-
def init(logits_processor, context) do
223+
def init(logits_processor, init_context) do
224224
initial_enforced_token_id = Nx.tensor([logits_processor.initial_enforced_token_id])
225225

226226
[initial_enforced_batch_token_id, _sequence] =
227-
Nx.broadcast_vectors([initial_enforced_token_id, context.sequence])
227+
Nx.broadcast_vectors([initial_enforced_token_id, init_context.sequence])
228228

229229
%{
230230
next_enforced_token_id: initial_enforced_batch_token_id
231231
}
232232
end
233233

234234
@impl Bumblebee.LogitsProcessor
235-
def process(_logits_processor, state, logits, _context) do
235+
def process(_logits_processor, state, logits, _process_context) do
236236
next_enforced_token_id = state.next_enforced_token_id
237237

238238
logits = enforce_token(logits, next_enforced_token_id)

0 commit comments

Comments
 (0)