From fc0825a460cd6ac68e6becae9f43cbe56f2499cd Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Fri, 17 Oct 2025 11:59:21 +0200 Subject: [PATCH 01/36] [#SAMPLE-6] Add state to logits processing https://bitcrowd.atlassian.net/browse/SAMPLE-6 From 01ab3af7c2c4b9844b753f51c38dfaf96e265417 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Thu, 16 Oct 2025 18:04:54 +0200 Subject: [PATCH 02/36] stateful logits processors --- lib/bumblebee/text/generation.ex | 45 +++++++++++------ .../generation/logits_processing_test.exs | 50 ++++++++++++++++++- 2 files changed, 79 insertions(+), 16 deletions(-) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 935c4921..acafabde 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -387,8 +387,12 @@ defmodule Bumblebee.Text.Generation do end ++ logits_processors fn logits, context -> - for processor <- processors, processor, reduce: logits do - logits -> processor.(logits, context) + for processor <- processors, processor, reduce: {logits, context} do + {logits, context} -> + case processor.(logits, context) do + {logits, new_context} -> {logits, new_context} + logits -> {logits, context} + end end end end @@ -551,7 +555,8 @@ defmodule Bumblebee.Text.Generation do length: length, finished_length: finished_length, # The ignored return value that we attach all hooks to - ignored: Nx.broadcast(0, {batch_size}) + ignored: Nx.broadcast(0, {batch_size}), + logits_processor_states: %{} } end @@ -574,7 +579,7 @@ defmodule Bumblebee.Text.Generation do outputs = predict_fun.(params, inputs) logits = outputs.logits[[.., -1]] - logits = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_fun, logits, state) token_id = Nx.argmax(logits, axis: -1) state = update_sequences(state, token_id, pad_token_id, eos_token_id) @@ -632,14 +637,24 @@ defmodule Bumblebee.Text.Generation do end defnp batch_process_logits(logits_processor_fun, logits, state) do - logits - |> Nx.vectorize(:batch) - |> logits_processor_fun.(%{ - sequence: Nx.vectorize(state.sequences, :batch), - length: state.length, - input_length: state.input_length - }) - |> Nx.devectorize(keep_names: false) + logits = Nx.vectorize(logits, :batch) + + {logits, new_context} = + logits_processor_fun.(logits, %{ + sequence: Nx.vectorize(state.sequences, :batch), + length: state.length, + input_length: state.input_length, + logits_processor_states: state.logits_processor_states + }) + + logits = Nx.devectorize(logits, keep_names: false) + + logits_processor_states = + Nx.devectorize(new_context.logits_processor_states, keep_names: false) + + sequences = Nx.devectorize(new_context.sequence, keep_names: false) + + {logits, %{state | sequences: sequences, logits_processor_states: logits_processor_states}} end # Contrastive search @@ -684,7 +699,7 @@ defmodule Bumblebee.Text.Generation do joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state) logits = outputs.logits[[.., -1]] - logits = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_fun, logits, state) scores = Axon.Activations.softmax(logits, axis: -1) {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k) @@ -727,7 +742,7 @@ defmodule Bumblebee.Text.Generation do logits = outputs.logits[[.., -1]] logits = Utils.Nx.chunked_take(logits, top_k, selected_idx) - logits = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_fun, logits, state) scores = Axon.Activations.softmax(logits, axis: -1) {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k) @@ -888,7 +903,7 @@ defmodule Bumblebee.Text.Generation do outputs = predict_fun.(params, inputs) logits = outputs.logits[[.., -1]] - logits = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_fun, logits, state) scores = Axon.Activations.softmax(logits) token_id = batched_choice(key, scores) diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs index 5bc5a44f..dc223a13 100644 --- a/test/bumblebee/text/generation/logits_processing_test.exs +++ b/test/bumblebee/text/generation/logits_processing_test.exs @@ -5,6 +5,53 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do alias Bumblebee.Text.Generation.LogitsProcessing + describe "stateful logits processors" do + defmodule StatefulLogitsProcessing do + import Nx.Defn + + deftransform stateful_processor(logits, context, opts) do + initial_suppressed_index = Nx.tensor([opts[:initial_suppressed_index]]) + + suppressed_index = + context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index + + values = + Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index)) + + logits = Nx.indexed_put(logits, suppressed_index, values) + + next_suppressed_index = Nx.add(suppressed_index, Nx.tensor([1])) + + context = + put_in( + context, + [:logits_processor_states, :next_suppressed_index], + next_suppressed_index + ) + + {logits, context} + end + end + + test "can register and modify state" do + logits = Nx.tensor([1.0, 2.0, 3.0, 4.0]) + + context = context([1, 0, 0, 0]) + + {logits, context} = + StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_index: 0) + + assert_equal(logits, Nx.tensor([:neg_infinity, 2.0, 3.0, 4.0])) + assert_equal(context.logits_processor_states.next_suppressed_index, Nx.tensor([1])) + + {logits, context} = + StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_index: 0) + + assert_equal(logits, Nx.tensor([:neg_infinity, :neg_infinity, 3.0, 4.0])) + assert_equal(context.logits_processor_states.next_suppressed_index, Nx.tensor([2])) + end + end + describe "suppressed_tokens_processor/3" do test "ignores the given tokens" do logits = Nx.tensor([1.0, 2.0, 3.0, 4.0]) @@ -382,7 +429,8 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do %{ sequence: Nx.tensor(sequence), length: Enum.count(sequence, &(&1 != 0)), - input_length: 1 + input_length: 1, + logits_processor_states: %{} } end end From 5413662d1517525451da12dda31f8b3d0b1a3fee Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Thu, 16 Oct 2025 18:24:07 +0200 Subject: [PATCH 03/36] adding another test --- lib/bumblebee/text/generation.ex | 1 + test/bumblebee/text/generation_test.exs | 57 +++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index acafabde..8396497d 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -644,6 +644,7 @@ defmodule Bumblebee.Text.Generation do sequence: Nx.vectorize(state.sequences, :batch), length: state.length, input_length: state.input_length, + # logits_processor_state: Nx.vectorize(state.logits_processor_states, :batch) logits_processor_states: state.logits_processor_states }) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index ff9854a1..f1782a13 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -106,4 +106,61 @@ defmodule Bumblebee.Text.GenerationTest do assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]])) end + + test "with stateful logits processor" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) + + {:ok, generation_config} = + Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) + + assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]), + "attention_mask" => Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), + "seed" => Nx.tensor([0]) + } + + generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) + + generate = + Bumblebee.Text.Generation.build_generate(model, spec, generation_config, + logits_processors: + &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, + initial_suppressed_index: 0 + ) + ) + + %{token_ids: token_ids} = generate.(params, inputs) + + assert_equal(token_ids, Nx.tensor([[80, 80, 80]])) + end + + defmodule StatefulLogitsProcessing do + import Nx.Defn + + deftransform stateful_processor(logits, context, opts) do + initial_suppressed_index = Nx.tensor([opts[:initial_suppressed_index]]) + + suppressed_index = + context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index + + values = + Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index)) + + logits = Nx.indexed_put(logits, suppressed_index, values) + + next_suppressed_index = Nx.add(suppressed_index, Nx.tensor([1])) + + context = + put_in( + context, + [:logits_processor_states, :next_suppressed_index], + next_suppressed_index + ) + + {logits, context} + end + end end From 9d4ef39c4fe52fe6e70a6bff147aa835c1698fd9 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 20 Oct 2025 15:25:04 +0200 Subject: [PATCH 04/36] fix test so compilation works --- test/bumblebee/text/generation_test.exs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index f1782a13..1c8c0fe1 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -126,10 +126,11 @@ defmodule Bumblebee.Text.GenerationTest do generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config, - logits_processors: + logits_processors: [ &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, initial_suppressed_index: 0 ) + ] ) %{token_ids: token_ids} = generate.(params, inputs) @@ -140,18 +141,15 @@ defmodule Bumblebee.Text.GenerationTest do defmodule StatefulLogitsProcessing do import Nx.Defn - deftransform stateful_processor(logits, context, opts) do + deftransform stateful_processor(logits, context, opts \\ []) do initial_suppressed_index = Nx.tensor([opts[:initial_suppressed_index]]) suppressed_index = context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index - values = - Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index)) + logits = suppress_index(logits, suppressed_index) - logits = Nx.indexed_put(logits, suppressed_index, values) - - next_suppressed_index = Nx.add(suppressed_index, Nx.tensor([1])) + next_suppressed_index = Nx.add(suppressed_index, 1) context = put_in( @@ -162,5 +160,9 @@ defmodule Bumblebee.Text.GenerationTest do {logits, context} end + + defnp suppress_index(logits, index) do + Nx.indexed_put(logits, index, Nx.Constants.neg_infinity(Nx.type(logits))) + end end end From 4ce01cc47e1d0f60994f1247b0ce05a8affedfa4 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 20 Oct 2025 15:42:50 +0200 Subject: [PATCH 05/36] demonstrate stateful logits processor through test assertions --- test/bumblebee/text/generation_test.exs | 27 +++++++++++++++---------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 1c8c0fe1..b614fd45 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -128,41 +128,46 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, - initial_suppressed_index: 0 + initial_suppressed_id: 79 ) ] ) %{token_ids: token_ids} = generate.(params, inputs) - assert_equal(token_ids, Nx.tensor([[80, 80, 80]])) + # result without logit processor: 80, 80 + + # first token_id still 80 as we suppress token_id 79 + assert_equal(token_ids[[0, 0]], 80) + # in the next step we increment from 79 to 80 and suppress token_id 80 + assert_equal(token_ids[[0, 1]], 176) end defmodule StatefulLogitsProcessing do import Nx.Defn deftransform stateful_processor(logits, context, opts \\ []) do - initial_suppressed_index = Nx.tensor([opts[:initial_suppressed_index]]) + initial_suppressed_id = Nx.tensor([opts[:initial_suppressed_id]]) - suppressed_index = - context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index + suppressed_id = + context.logits_processor_states[:next_suppressed_id] || initial_suppressed_id - logits = suppress_index(logits, suppressed_index) + logits = suppress_id(logits, suppressed_id) - next_suppressed_index = Nx.add(suppressed_index, 1) + next_suppressed_id = Nx.add(suppressed_id, 1) context = put_in( context, - [:logits_processor_states, :next_suppressed_index], - next_suppressed_index + [:logits_processor_states, :next_suppressed_id], + next_suppressed_id ) {logits, context} end - defnp suppress_index(logits, index) do - Nx.indexed_put(logits, index, Nx.Constants.neg_infinity(Nx.type(logits))) + defnp suppress_id(logits, id) do + Nx.indexed_put(logits, id, Nx.Constants.neg_infinity(Nx.type(logits))) end end end From 2161b775b7ba884fda0bfaa85836b5b38b735b61 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 20 Oct 2025 16:50:19 +0200 Subject: [PATCH 06/36] independent state for batch entries --- lib/bumblebee/text/generation.ex | 5 ++-- test/bumblebee/text/generation_test.exs | 40 ++++++++++++++++++------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 8396497d..51ccb58a 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -644,14 +644,13 @@ defmodule Bumblebee.Text.Generation do sequence: Nx.vectorize(state.sequences, :batch), length: state.length, input_length: state.input_length, - # logits_processor_state: Nx.vectorize(state.logits_processor_states, :batch) - logits_processor_states: state.logits_processor_states + logits_processor_state: Nx.vectorize(state.logits_processor_states, :batch) }) logits = Nx.devectorize(logits, keep_names: false) logits_processor_states = - Nx.devectorize(new_context.logits_processor_states, keep_names: false) + Nx.devectorize(new_context.logits_processor_state, keep_names: false) sequences = Nx.devectorize(new_context.sequence, keep_names: false) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index b614fd45..844fd43d 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -116,10 +116,14 @@ defmodule Bumblebee.Text.GenerationTest do assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec + input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]) + attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]) + seed = Nx.tensor([0]) + inputs = %{ - "input_ids" => Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]), - "attention_mask" => Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), - "seed" => Nx.tensor([0]) + "input_ids" => Nx.Batch.concatenate([input_ids, input_ids]), + "attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]), + "seed" => Nx.Batch.concatenate([seed, seed]) } generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) @@ -128,29 +132,39 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, - initial_suppressed_id: 79 + initial_suppressed_id: [78, 79] ) ] ) %{token_ids: token_ids} = generate.(params, inputs) - # result without logit processor: 80, 80 + # result without logit processor: 80, 80, 80 - # first token_id still 80 as we suppress token_id 79 + # first entry in batch + # first token_id still 80 as we suppress token_id 78 assert_equal(token_ids[[0, 0]], 80) + # second token_id still 80 as we suppress token_id 79 + assert_equal(token_ids[[0, 1]], 80) # in the next step we increment from 79 to 80 and suppress token_id 80 - assert_equal(token_ids[[0, 1]], 176) + assert_equal(token_ids[[0, 2]], 1016) + + # second entry in batch + # first token_id still 80 as we suppress token_id 79 + assert_equal(token_ids[[1, 0]], 80) + # in the next step we increment from 79 to 80 and suppress token_id 80 + assert_equal(token_ids[[1, 1]], 176) end defmodule StatefulLogitsProcessing do import Nx.Defn deftransform stateful_processor(logits, context, opts \\ []) do - initial_suppressed_id = Nx.tensor([opts[:initial_suppressed_id]]) + initial_suppressed_ids = Enum.map(opts[:initial_suppressed_id], &List.wrap(&1)) + initial_suppressed_id = Nx.tensor(initial_suppressed_ids) |> Nx.vectorize(:batch) suppressed_id = - context.logits_processor_states[:next_suppressed_id] || initial_suppressed_id + context.logits_processor_state[:next_suppressed_id] || initial_suppressed_id logits = suppress_id(logits, suppressed_id) @@ -159,7 +173,7 @@ defmodule Bumblebee.Text.GenerationTest do context = put_in( context, - [:logits_processor_states, :next_suppressed_id], + [:logits_processor_state, :next_suppressed_id], next_suppressed_id ) @@ -167,7 +181,11 @@ defmodule Bumblebee.Text.GenerationTest do end defnp suppress_id(logits, id) do - Nx.indexed_put(logits, id, Nx.Constants.neg_infinity(Nx.type(logits))) + Nx.indexed_put( + logits, + id, + Nx.Constants.neg_infinity(Nx.type(logits)) + ) end end end From fefc9fdf548436707bea4d5476a87c96cd074255 Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 21 Oct 2025 15:57:01 +0200 Subject: [PATCH 07/36] renamed initial_suppressed_token_index for clarity --- test/bumblebee/text/generation/logits_processing_test.exs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs index dc223a13..2ebc03b9 100644 --- a/test/bumblebee/text/generation/logits_processing_test.exs +++ b/test/bumblebee/text/generation/logits_processing_test.exs @@ -10,10 +10,10 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do import Nx.Defn deftransform stateful_processor(logits, context, opts) do - initial_suppressed_index = Nx.tensor([opts[:initial_suppressed_index]]) + initial_suppressed_token_index = Nx.tensor([opts[:initial_suppressed_token_index]]) suppressed_index = - context.logits_processor_states[:next_suppressed_index] || initial_suppressed_index + context.logits_processor_states[:next_suppressed_index] || initial_suppressed_token_index values = Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index)) @@ -39,13 +39,13 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do context = context([1, 0, 0, 0]) {logits, context} = - StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_index: 0) + StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0) assert_equal(logits, Nx.tensor([:neg_infinity, 2.0, 3.0, 4.0])) assert_equal(context.logits_processor_states.next_suppressed_index, Nx.tensor([1])) {logits, context} = - StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_index: 0) + StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0) assert_equal(logits, Nx.tensor([:neg_infinity, :neg_infinity, 3.0, 4.0])) assert_equal(context.logits_processor_states.next_suppressed_index, Nx.tensor([2])) From 6e8612a71e9d1af0cf89ac3d9f267fb670da9275 Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 21 Oct 2025 16:29:30 +0200 Subject: [PATCH 08/36] renamend next_suppressed_index -> :next_suppressed_token_index --- .../text/generation/logits_processing_test.exs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs index 2ebc03b9..6d41e52c 100644 --- a/test/bumblebee/text/generation/logits_processing_test.exs +++ b/test/bumblebee/text/generation/logits_processing_test.exs @@ -13,20 +13,20 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do initial_suppressed_token_index = Nx.tensor([opts[:initial_suppressed_token_index]]) suppressed_index = - context.logits_processor_states[:next_suppressed_index] || initial_suppressed_token_index + context.logits_processor_states[:next_suppressed_token_index] || initial_suppressed_token_index values = Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index)) logits = Nx.indexed_put(logits, suppressed_index, values) - next_suppressed_index = Nx.add(suppressed_index, Nx.tensor([1])) + next_suppressed_token_index = Nx.add(suppressed_index, Nx.tensor([1])) context = put_in( context, - [:logits_processor_states, :next_suppressed_index], - next_suppressed_index + [:logits_processor_states, :next_suppressed_token_index], + next_suppressed_token_index ) {logits, context} @@ -42,13 +42,13 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0) assert_equal(logits, Nx.tensor([:neg_infinity, 2.0, 3.0, 4.0])) - assert_equal(context.logits_processor_states.next_suppressed_index, Nx.tensor([1])) + assert_equal(context.logits_processor_states.next_suppressed_token_index, Nx.tensor([1])) {logits, context} = StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0) assert_equal(logits, Nx.tensor([:neg_infinity, :neg_infinity, 3.0, 4.0])) - assert_equal(context.logits_processor_states.next_suppressed_index, Nx.tensor([2])) + assert_equal(context.logits_processor_states.next_suppressed_token_index, Nx.tensor([2])) end end From e43254a02afdce62a0d3616dafb9d6dafb5d42f0 Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 21 Oct 2025 16:32:49 +0200 Subject: [PATCH 09/36] logits_processor_states -> logits_processor_state in batch tests --- .../text/generation/logits_processing_test.exs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs index 6d41e52c..1aefabdd 100644 --- a/test/bumblebee/text/generation/logits_processing_test.exs +++ b/test/bumblebee/text/generation/logits_processing_test.exs @@ -13,7 +13,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do initial_suppressed_token_index = Nx.tensor([opts[:initial_suppressed_token_index]]) suppressed_index = - context.logits_processor_states[:next_suppressed_token_index] || initial_suppressed_token_index + context.logits_processor_state[:next_suppressed_token_index] || initial_suppressed_token_index values = Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index)) @@ -25,7 +25,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do context = put_in( context, - [:logits_processor_states, :next_suppressed_token_index], + [:logits_processor_state, :next_suppressed_token_index], next_suppressed_token_index ) @@ -42,13 +42,13 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0) assert_equal(logits, Nx.tensor([:neg_infinity, 2.0, 3.0, 4.0])) - assert_equal(context.logits_processor_states.next_suppressed_token_index, Nx.tensor([1])) + assert_equal(context.logits_processor_state.next_suppressed_token_index, Nx.tensor([1])) {logits, context} = StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0) assert_equal(logits, Nx.tensor([:neg_infinity, :neg_infinity, 3.0, 4.0])) - assert_equal(context.logits_processor_states.next_suppressed_token_index, Nx.tensor([2])) + assert_equal(context.logits_processor_state.next_suppressed_token_index, Nx.tensor([2])) end end @@ -430,7 +430,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do sequence: Nx.tensor(sequence), length: Enum.count(sequence, &(&1 != 0)), input_length: 1, - logits_processor_states: %{} + logits_processor_state: %{} } end end From a2f0015c69f720a41a2db64662b7b673d95528f0 Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 21 Oct 2025 17:02:53 +0200 Subject: [PATCH 10/36] added a test with batch size 1 for clarity --- test/bumblebee/text/generation_test.exs | 59 +++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 844fd43d..227cd1b6 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -107,7 +107,57 @@ defmodule Bumblebee.Text.GenerationTest do assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]])) end - test "with stateful logits processor" do + + test "with stateful logits processor with batch size of 1" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) + + {:ok, generation_config} = + Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) + + assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec + + input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]) + attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]) + seed = Nx.tensor([0]) + + inputs = %{ + "input_ids" => input_ids, + "attention_mask" => attention_mask, + "seed" => seed + } + + # in the first decoder test above we got 80, 80, 80 consistently. + # we use the stateful processor below + # it suppresses the given initial id on the first iteration, then increments + #the id to be suppressed on the following iterations. + # So it surpresses 79, 80, ... in the iterations, demonstrating the use + # of the state in a processor. + + generation_config = Bumblebee.configure(generation_config, max_new_tokens: 2) + + generate = + Bumblebee.Text.Generation.build_generate(model, spec, generation_config, + logits_processors: [ + &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, + initial_suppressed_id: [79] + ) + ] + ) + + %{token_ids: token_ids} = generate.(params, inputs) + + # result without logit processor: 80, 80, 80 + + # first token_id still 80 as we suppress token_id 79 + assert_equal(token_ids[[0,0]], 80) + # in the next step we increment from 79 to 80 and suppress token_id 80, the + #result is 176 as that is the next likelihood in the logits. + + assert_equal(token_ids[[0,1]], 176) + end + + test "with stateful logits processor with batch size of 2" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) @@ -126,6 +176,9 @@ defmodule Bumblebee.Text.GenerationTest do "seed" => Nx.Batch.concatenate([seed, seed]) } + # this is the same example as above, but with a batch size of 2. + + generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) generate = @@ -146,13 +199,13 @@ defmodule Bumblebee.Text.GenerationTest do assert_equal(token_ids[[0, 0]], 80) # second token_id still 80 as we suppress token_id 79 assert_equal(token_ids[[0, 1]], 80) - # in the next step we increment from 79 to 80 and suppress token_id 80 + # in the next step we increment from 79 to 80 and suppress token_id 80 assert_equal(token_ids[[0, 2]], 1016) # second entry in batch # first token_id still 80 as we suppress token_id 79 assert_equal(token_ids[[1, 0]], 80) - # in the next step we increment from 79 to 80 and suppress token_id 80 + # in the next step we increment from 79 to 80 and suppress token_id 80 assert_equal(token_ids[[1, 1]], 176) end From 0cdc0adfa2c1877e548022d139a18aae12f9135d Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 21 Oct 2025 17:08:42 +0200 Subject: [PATCH 11/36] renaming suppressed_id -> suppressed_token_id --- test/bumblebee/text/generation_test.exs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 227cd1b6..25adfb52 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -140,7 +140,7 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, - initial_suppressed_id: [79] + initial_suppressed_token_id: [79] ) ] ) @@ -185,7 +185,7 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, - initial_suppressed_id: [78, 79] + initial_suppressed_token_id: [78, 79] ) ] ) @@ -213,21 +213,21 @@ defmodule Bumblebee.Text.GenerationTest do import Nx.Defn deftransform stateful_processor(logits, context, opts \\ []) do - initial_suppressed_ids = Enum.map(opts[:initial_suppressed_id], &List.wrap(&1)) - initial_suppressed_id = Nx.tensor(initial_suppressed_ids) |> Nx.vectorize(:batch) + initial_suppressed_token_ids = Enum.map(opts[:initial_suppressed_token_id], &List.wrap(&1)) + initial_suppressed_token_id = Nx.tensor(initial_suppressed_token_ids) |> Nx.vectorize(:batch) suppressed_id = - context.logits_processor_state[:next_suppressed_id] || initial_suppressed_id + context.logits_processor_state[:next_suppressed_token_id] || initial_suppressed_token_id logits = suppress_id(logits, suppressed_id) - next_suppressed_id = Nx.add(suppressed_id, 1) + next_suppressed_token_id = Nx.add(suppressed_id, 1) context = put_in( context, - [:logits_processor_state, :next_suppressed_id], - next_suppressed_id + [:logits_processor_state, :next_suppressed_token_id], + next_suppressed_token_id ) {logits, context} From cc6d6e30944a4e9070e87d604c6f0bccb020b200 Mon Sep 17 00:00:00 2001 From: Chris Date: Tue, 21 Oct 2025 17:23:49 +0200 Subject: [PATCH 12/36] more comments --- test/bumblebee/text/generation_test.exs | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 25adfb52..05588d70 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -127,12 +127,15 @@ defmodule Bumblebee.Text.GenerationTest do "seed" => seed } - # in the first decoder test above we got 80, 80, 80 consistently. - # we use the stateful processor below - # it suppresses the given initial id on the first iteration, then increments - #the id to be suppressed on the following iterations. - # So it surpresses 79, 80, ... in the iterations, demonstrating the use - # of the state in a processor. + # We demonstrate the use of the state with the following example of a + # stateful processor (see below). On the first iteration, it suppresses the + # given initial ID, then increments the token ID to be suppressed on the + # following iterations. The ID of the token to be suppressed is passed on + # between iterations using the logits_processor_state. + # + # So invoked with the initial ID of 79, it suppresses 79, 80, 81, ... in + # the subsequent iterations, demonstrating the use of the state in a + # logits processor. generation_config = Bumblebee.configure(generation_config, max_new_tokens: 2) @@ -145,9 +148,16 @@ defmodule Bumblebee.Text.GenerationTest do ] ) + # The result without the logits processor would be, as with the first + # decoder test above: 80, 80, 80. + # + # Now, with the processor below, we expect no change (suppressed token ID is + # 79), then a change to another random token ID (176) as the suppressed + # token ID is incremented from 79 to 80, disallowing the previous most + # likely token ID (80) from being selected. + %{token_ids: token_ids} = generate.(params, inputs) - # result without logit processor: 80, 80, 80 # first token_id still 80 as we suppress token_id 79 assert_equal(token_ids[[0,0]], 80) From 3816e7c4f5906ec06571fce60b7c6515ef009066 Mon Sep 17 00:00:00 2001 From: Chris Date: Thu, 23 Oct 2025 14:39:18 +0200 Subject: [PATCH 13/36] changed to to demonstrate stack functionality --- test/bumblebee/text/generation_test.exs | 82 ++++++++++++------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 05588d70..24f7dc6d 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -128,12 +128,12 @@ defmodule Bumblebee.Text.GenerationTest do } # We demonstrate the use of the state with the following example of a - # stateful processor (see below). On the first iteration, it suppresses the - # given initial ID, then increments the token ID to be suppressed on the - # following iterations. The ID of the token to be suppressed is passed on + # stateful processor (see below). On the first iteration, it enforces the + # given initial ID, then increments the token ID to be enforced on the + # following iterations. The ID of the token to be enforced is passed on # between iterations using the logits_processor_state. # - # So invoked with the initial ID of 79, it suppresses 79, 80, 81, ... in + # So invoked with the initial ID of 79, it enforces 79, 80, 81, ... in # the subsequent iterations, demonstrating the use of the state in a # logits processor. @@ -143,28 +143,23 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, - initial_suppressed_token_id: [79] + initial_enforced_token_ids: [79] ) ] ) # The result without the logits processor would be, as with the first - # decoder test above: 80, 80, 80. + # decoder test above, [80, 80, 80]. # - # Now, with the processor below, we expect no change (suppressed token ID is - # 79), then a change to another random token ID (176) as the suppressed - # token ID is incremented from 79 to 80, disallowing the previous most - # likely token ID (80) from being selected. + # Now, with the processor below, we expect the sequence of [79, 80, 81 ..] %{token_ids: token_ids} = generate.(params, inputs) + # first token_id should be 79 as we enforce token_id 79 + assert_equal(token_ids[[0,0]], 79) - # first token_id still 80 as we suppress token_id 79 - assert_equal(token_ids[[0,0]], 80) - # in the next step we increment from 79 to 80 and suppress token_id 80, the - #result is 176 as that is the next likelihood in the logits. - - assert_equal(token_ids[[0,1]], 176) + # in the next step we increment from 79 to 80 and enforce token_id 80 + assert_equal(token_ids[[0,1]], 80) end test "with stateful logits processor with batch size of 2" do @@ -195,7 +190,7 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, - initial_suppressed_token_id: [78, 79] + initial_enforced_token_ids: [78, 20] ) ] ) @@ -205,50 +200,55 @@ defmodule Bumblebee.Text.GenerationTest do # result without logit processor: 80, 80, 80 # first entry in batch - # first token_id still 80 as we suppress token_id 78 - assert_equal(token_ids[[0, 0]], 80) - # second token_id still 80 as we suppress token_id 79 - assert_equal(token_ids[[0, 1]], 80) - # in the next step we increment from 79 to 80 and suppress token_id 80 - assert_equal(token_ids[[0, 2]], 1016) + # first token_id should be 78 as we enforce token_id 78 on the first + # iteration + assert_equal(token_ids[[0, 0]], 78) + + # second should be 79 as we increment the enforced token_id from 78 to 79 + assert_equal(token_ids[[0, 1]], 79) + + # in the next step we increment from 79 to 80 and enforce token_id 80 + assert_equal(token_ids[[0, 2]], 80) # second entry in batch - # first token_id still 80 as we suppress token_id 79 - assert_equal(token_ids[[1, 0]], 80) - # in the next step we increment from 79 to 80 and suppress token_id 80 - assert_equal(token_ids[[1, 1]], 176) + # first token_id is 20 as we enforce token_id 20 on the first iteration + assert_equal(token_ids[[1, 0]], 20) + + # in the next step we increment from 20 to 21 and enforce token_id 21 + assert_equal(token_ids[[1, 1]], 21) end defmodule StatefulLogitsProcessing do import Nx.Defn deftransform stateful_processor(logits, context, opts \\ []) do - initial_suppressed_token_ids = Enum.map(opts[:initial_suppressed_token_id], &List.wrap(&1)) - initial_suppressed_token_id = Nx.tensor(initial_suppressed_token_ids) |> Nx.vectorize(:batch) + # initial_enforced_token_ids = opts[:initial_enforced_token_ids] + initial_enforced_token_ids = Enum.map(opts[:initial_enforced_token_ids], &List.wrap(&1)) + # Enum.map(opts[:initial_suppressed_token_id], &List.wrap(&1)) + # pick the actual token id from the batch + initial_enforced_token_id = Nx.tensor(initial_enforced_token_ids) |> Nx.vectorize(:batch) - suppressed_id = - context.logits_processor_state[:next_suppressed_token_id] || initial_suppressed_token_id + enforced_token_id = + context.logits_processor_state[:next_enforced_token_id] || initial_enforced_token_id - logits = suppress_id(logits, suppressed_id) + logits = enforce_token(logits, enforced_token_id) - next_suppressed_token_id = Nx.add(suppressed_id, 1) + next_enforced_token_id = Nx.add(enforced_token_id, 1) context = put_in( context, - [:logits_processor_state, :next_suppressed_token_id], - next_suppressed_token_id + [:logits_processor_state, :next_enforced_token_id], + next_enforced_token_id ) {logits, context} end - defnp suppress_id(logits, id) do - Nx.indexed_put( - logits, - id, - Nx.Constants.neg_infinity(Nx.type(logits)) - ) + defnp enforce_token(logits, token_id) do + logits + |> Nx.fill(Nx.Constants.neg_infinity(), type: Nx.type(logits)) + |> Nx.indexed_put(token_id, Nx.tensor(0, type: Nx.type(logits))) end end end From fe587121faa1e4464576b48bd4c3ed14be97f452 Mon Sep 17 00:00:00 2001 From: Chris Date: Thu, 23 Oct 2025 14:41:58 +0200 Subject: [PATCH 14/36] merged tests --- test/bumblebee/text/generation_test.exs | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 24f7dc6d..2a3b37d8 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -121,6 +121,9 @@ defmodule Bumblebee.Text.GenerationTest do attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]) seed = Nx.tensor([0]) + ######################################################### + # batch size of 1 + inputs = %{ "input_ids" => input_ids, "attention_mask" => attention_mask, @@ -160,20 +163,9 @@ defmodule Bumblebee.Text.GenerationTest do # in the next step we increment from 79 to 80 and enforce token_id 80 assert_equal(token_ids[[0,1]], 80) - end - test "with stateful logits processor with batch size of 2" do - assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) - - {:ok, generation_config} = - Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) - - assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec - - input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]]) - attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]) - seed = Nx.tensor([0]) + ######################################################### + # batch size of 2 inputs = %{ "input_ids" => Nx.Batch.concatenate([input_ids, input_ids]), @@ -222,14 +214,11 @@ defmodule Bumblebee.Text.GenerationTest do import Nx.Defn deftransform stateful_processor(logits, context, opts \\ []) do - # initial_enforced_token_ids = opts[:initial_enforced_token_ids] initial_enforced_token_ids = Enum.map(opts[:initial_enforced_token_ids], &List.wrap(&1)) - # Enum.map(opts[:initial_suppressed_token_id], &List.wrap(&1)) - # pick the actual token id from the batch - initial_enforced_token_id = Nx.tensor(initial_enforced_token_ids) |> Nx.vectorize(:batch) + initial_enforced_batch_token_id = Nx.tensor(initial_enforced_token_ids) |> Nx.vectorize(:batch) enforced_token_id = - context.logits_processor_state[:next_enforced_token_id] || initial_enforced_token_id + context.logits_processor_state[:next_enforced_token_id] || initial_enforced_batch_token_id logits = enforce_token(logits, enforced_token_id) From c97890aea59b3e0fe5279ebb81e94dd65161653c Mon Sep 17 00:00:00 2001 From: Chris Date: Thu, 23 Oct 2025 14:44:27 +0200 Subject: [PATCH 15/36] removed test for processor only used in test --- .../generation/logits_processing_test.exs | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs index 1aefabdd..190e97c4 100644 --- a/test/bumblebee/text/generation/logits_processing_test.exs +++ b/test/bumblebee/text/generation/logits_processing_test.exs @@ -5,53 +5,6 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do alias Bumblebee.Text.Generation.LogitsProcessing - describe "stateful logits processors" do - defmodule StatefulLogitsProcessing do - import Nx.Defn - - deftransform stateful_processor(logits, context, opts) do - initial_suppressed_token_index = Nx.tensor([opts[:initial_suppressed_token_index]]) - - suppressed_index = - context.logits_processor_state[:next_suppressed_token_index] || initial_suppressed_token_index - - values = - Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index)) - - logits = Nx.indexed_put(logits, suppressed_index, values) - - next_suppressed_token_index = Nx.add(suppressed_index, Nx.tensor([1])) - - context = - put_in( - context, - [:logits_processor_state, :next_suppressed_token_index], - next_suppressed_token_index - ) - - {logits, context} - end - end - - test "can register and modify state" do - logits = Nx.tensor([1.0, 2.0, 3.0, 4.0]) - - context = context([1, 0, 0, 0]) - - {logits, context} = - StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0) - - assert_equal(logits, Nx.tensor([:neg_infinity, 2.0, 3.0, 4.0])) - assert_equal(context.logits_processor_state.next_suppressed_token_index, Nx.tensor([1])) - - {logits, context} = - StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0) - - assert_equal(logits, Nx.tensor([:neg_infinity, :neg_infinity, 3.0, 4.0])) - assert_equal(context.logits_processor_state.next_suppressed_token_index, Nx.tensor([2])) - end - end - describe "suppressed_tokens_processor/3" do test "ignores the given tokens" do logits = Nx.tensor([1.0, 2.0, 3.0, 4.0]) From fbf5ef30dc1bd193363e72435dbdd51de025e56f Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 24 Oct 2025 15:42:36 +0200 Subject: [PATCH 16/36] introduces LogitsProcessor module --- lib/bumblebee.ex | 30 ++++ lib/bumblebee/logits_processor.ex | 39 ++++++ lib/bumblebee/text/generation.ex | 132 ++++++++++++------ .../generation/stateless_logits_processor.ex | 31 ++++ test/bumblebee/text/generation_test.exs | 68 +++++---- 5 files changed, 231 insertions(+), 69 deletions(-) create mode 100644 lib/bumblebee/logits_processor.ex create mode 100644 lib/bumblebee/text/generation/stateless_logits_processor.ex diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 51f2330f..06ae8d09 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -1083,6 +1083,36 @@ defmodule Bumblebee do end end + @doc """ + Initializes state for a new logits processor. + + Returns `state`, which is an opaque `Nx.Container`, and it is then + passed to and returned from `process/4`. + """ + @doc type: :logits_processor + @spec logits_processor_init( + Bumblebee.LogitsProcessor.t(), + context :: term() + ) :: Bumblebee.LogitsProcessor.state() + def logits_processor_init(%module{} = logits_processor, context) do + module.init(logits_processor, context) + end + + @doc """ + Processes logits, applying specific rules. Receives context, state and + logits, and returns updated logits and state. + """ + @doc type: :logits_processor + @spec logits_processor_process( + Bumblebee.LogitsProcessor.t(), + Bumblebee.LogitsProcessor.state(), + logits :: Nx.Tensor.t(), + context :: term() + ) :: {Bumblebee.LogitsProcessor.state(), logits :: Nx.Tensor.t()} + def logits_processor_process(%module{} = logits_processor, state, logits, context) do + module.process(logits_processor, state, logits, context) + end + @doc """ Initializes state for a new scheduler loop. diff --git a/lib/bumblebee/logits_processor.ex b/lib/bumblebee/logits_processor.ex new file mode 100644 index 00000000..83815dcb --- /dev/null +++ b/lib/bumblebee/logits_processor.ex @@ -0,0 +1,39 @@ +defmodule Bumblebee.LogitsProcessor do + @moduledoc """ + An interface for configuring and using logits processors. + + Logits processors are used during autoregressive generation to modify + predicted scores at each generation step. This allows for applying + certain rules to the model output to control which tokens are picked + at each generation step, and which are not. + + Every module implementing this behaviour is expected to also define + a configuration struct. + """ + + @type t :: Bumblebee.Configurable.t() + + @type state :: Nx.Container.t() + + + @doc """ + Initializes state for a new logits processor. + + Returns `state`, which is an opaque `Nx.Container`, and it is then + passed to and returned from `process/2`. + + Oftentimes logits processors are stateless, in which case this + function can return an empty container, such as `{}`. + """ + @callback init(t(), any()) :: state() + + @doc """ + Processes logits, applying specific rules. + """ + @callback process( + t(), + state(), + logits :: Nx.Tensor.t(), + context:: term() + ) :: {logits :: Nx.Tensor.t(), state :: map()} +end diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 51ccb58a..20a792ce 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -164,13 +164,14 @@ defmodule Bumblebee.Text.Generation do {_init_fun, predict_fun} = Axon.build(model, global_layer_options: global_layer_options) - logits_processor_fun = get_logits_processor(min_length_fun, config, opts[:logits_processors]) + {logits_processor_init_fun, logits_processor_process_fun} = get_logits_processor(min_length_fun, config, opts[:logits_processors]) &generate_impl( &2, predict_fun, &1, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, prepare_inputs_fun, update_inputs_fun, traverse_cache_fun, @@ -386,22 +387,51 @@ defmodule Bumblebee.Text.Generation do [] end ++ logits_processors - fn logits, context -> - for processor <- processors, processor, reduce: {logits, context} do - {logits, context} -> - case processor.(logits, context) do - {logits, new_context} -> {logits, new_context} - logits -> {logits, context} - end - end + + processors = + processors + |> Enum.filter(fn processor -> processor != nil end) + |> Enum.map(fn processor -> + if is_function(processor, 2) do + gna = %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: processor} + if gna.fun == nil do raise "hell" end + gna + else + if processor == nil do raise "heaven" end + processor + end + end) + + + init_fun = fn context -> + Enum.reduce(processors, %{}, fn processor, state_acc -> + state = Bumblebee.logits_processor_init(processor, context) + Map.merge(state_acc, state) + end) + end + + process_fun = fn logits, context, state -> + Enum.reduce(processors, {logits, state}, fn processor, {logits, state} -> + Bumblebee.logits_processor_process(processor, state, logits, context) + end) end + + {init_fun, process_fun} + + # Technically, the :logits_processors options is public API, but we can make it backward-compatible. For example, we + # can define %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: fun}, where the state is always empty and + # process just invokes the fun. I would even use that for the built-in processors, so that we don't need to define + # a bunch of new modules. + # + # %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: fun} end defnp generate_impl( inputs, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, prepare_inputs_fun, update_inputs_fun, traverse_cache_fun, @@ -431,7 +461,8 @@ defmodule Bumblebee.Text.Generation do padded_batch_item?, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, merge_options([max_length: max_length], opts) ) @@ -443,7 +474,8 @@ defmodule Bumblebee.Text.Generation do padded_batch_item?, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, traverse_cache_fun, merge_options( @@ -460,7 +492,8 @@ defmodule Bumblebee.Text.Generation do predict_fun, params, seed, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, merge_options([max_length: max_length], opts) ) @@ -489,7 +522,8 @@ defmodule Bumblebee.Text.Generation do padded_batch_item?, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, opts \\ [] ) do @@ -497,7 +531,8 @@ defmodule Bumblebee.Text.Generation do pad_token_id = opts[:pad_token_id] eos_token_id = opts[:eos_token_id] - state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) + + state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id, logits_processor_init_fun) # The loop works with inputs of length 1, so if the initial input # is longer, we make the initial pass outside @@ -508,7 +543,7 @@ defmodule Bumblebee.Text.Generation do inputs, predict_fun, params, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, pad_token_id: pad_token_id, eos_token_id: eos_token_id @@ -525,7 +560,7 @@ defmodule Bumblebee.Text.Generation do inputs, predict_fun, params, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, pad_token_id: pad_token_id, eos_token_id: eos_token_id @@ -537,7 +572,7 @@ defmodule Bumblebee.Text.Generation do state end - defnp init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) do + defnp init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id, logits_processor_init_fun) do {batch_size, length} = Nx.shape(decoder_input_ids) sequences = Nx.broadcast(pad_token_id, {batch_size, max_length}) @@ -549,6 +584,12 @@ defmodule Bumblebee.Text.Generation do # they could produce arbitrary tokens until we reach max length. finished_length = Nx.select(padded_batch_item?, 1, 0) + context = %{ + sequences: sequences, + input_length: length, + length: length, + } + %{ sequences: sequences, input_length: length, @@ -556,7 +597,7 @@ defmodule Bumblebee.Text.Generation do finished_length: finished_length, # The ignored return value that we attach all hooks to ignored: Nx.broadcast(0, {batch_size}), - logits_processor_states: %{} + logits_processor_state: logits_processor_init_fun.(context) } end @@ -569,7 +610,7 @@ defmodule Bumblebee.Text.Generation do inputs, predict_fun, params, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, opts ) do @@ -579,7 +620,7 @@ defmodule Bumblebee.Text.Generation do outputs = predict_fun.(params, inputs) logits = outputs.logits[[.., -1]] - {logits, state} = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state) token_id = Nx.argmax(logits, axis: -1) state = update_sequences(state, token_id, pad_token_id, eos_token_id) @@ -636,25 +677,26 @@ defmodule Bumblebee.Text.Generation do end end - defnp batch_process_logits(logits_processor_fun, logits, state) do + defnp batch_process_logits(logits_processor_process_fun, logits, state) do + logits = Nx.vectorize(logits, :batch) - {logits, new_context} = - logits_processor_fun.(logits, %{ - sequence: Nx.vectorize(state.sequences, :batch), - length: state.length, - input_length: state.input_length, - logits_processor_state: Nx.vectorize(state.logits_processor_states, :batch) - }) + context = %{ + sequences: Nx.vectorize(state.sequences, :batch), + length: state.length, + input_length: state.input_length + } + + {logits, new_logits_processor_state} = + logits_processor_process_fun.(logits, context, Nx.vectorize(state.logits_processor_state, :batch)) logits = Nx.devectorize(logits, keep_names: false) - logits_processor_states = - Nx.devectorize(new_context.logits_processor_state, keep_names: false) + logits_processor_state = + Nx.devectorize(new_logits_processor_state, keep_names: false) - sequences = Nx.devectorize(new_context.sequence, keep_names: false) - {logits, %{state | sequences: sequences, logits_processor_states: logits_processor_states}} + {logits, %{state | logits_processor_state: logits_processor_state}} end # Contrastive search @@ -665,7 +707,8 @@ defmodule Bumblebee.Text.Generation do padded_batch_item?, predict_fun, params, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, traverse_cache_fun, opts \\ [] @@ -676,7 +719,7 @@ defmodule Bumblebee.Text.Generation do top_k = opts[:top_k] penalty_alpha = opts[:penalty_alpha] - state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) + state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id, logits_processor_init_fun) # Step (1) # Initial pass to obtain hidden state and expand inputs to top-k @@ -699,7 +742,7 @@ defmodule Bumblebee.Text.Generation do joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state) logits = outputs.logits[[.., -1]] - {logits, state} = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state) scores = Axon.Activations.softmax(logits, axis: -1) {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k) @@ -742,7 +785,7 @@ defmodule Bumblebee.Text.Generation do logits = outputs.logits[[.., -1]] logits = Utils.Nx.chunked_take(logits, top_k, selected_idx) - {logits, state} = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state) scores = Axon.Activations.softmax(logits, axis: -1) {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k) @@ -832,7 +875,8 @@ defmodule Bumblebee.Text.Generation do predict_fun, params, seed, - logits_processor_fun, + logits_processor_init_fun, + logits_processor_process_fun, update_inputs_fun, opts \\ [] ) do @@ -840,7 +884,7 @@ defmodule Bumblebee.Text.Generation do pad_token_id = opts[:pad_token_id] eos_token_id = opts[:eos_token_id] - state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) + state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id, logits_processor_init_fun) prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key() @@ -854,7 +898,7 @@ defmodule Bumblebee.Text.Generation do predict_fun, params, prng_key, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, pad_token_id: pad_token_id, eos_token_id: eos_token_id @@ -872,7 +916,7 @@ defmodule Bumblebee.Text.Generation do predict_fun, params, prng_key, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, pad_token_id: pad_token_id, eos_token_id: eos_token_id @@ -890,7 +934,7 @@ defmodule Bumblebee.Text.Generation do predict_fun, params, prng_key, - logits_processor_fun, + logits_processor_process_fun, update_inputs_fun, opts \\ [] ) do @@ -903,7 +947,7 @@ defmodule Bumblebee.Text.Generation do outputs = predict_fun.(params, inputs) logits = outputs.logits[[.., -1]] - {logits, state} = batch_process_logits(logits_processor_fun, logits, state) + {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state) scores = Axon.Activations.softmax(logits) token_id = batched_choice(key, scores) diff --git a/lib/bumblebee/text/generation/stateless_logits_processor.ex b/lib/bumblebee/text/generation/stateless_logits_processor.ex new file mode 100644 index 00000000..16a7053a --- /dev/null +++ b/lib/bumblebee/text/generation/stateless_logits_processor.ex @@ -0,0 +1,31 @@ +defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do + @moduledoc false + + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.LogitsProcessor + + options = [ + fun: [ + default: nil, + doc: "a state-less function that is applied to the logits" + ] + ] + + defstruct Bumblebee.Shared.option_defaults(options) + + @impl Bumblebee.Configurable + def config(logits_processor, opts) do + Bumblebee.Shared.put_config_attrs(logits_processor, opts) + end + + @impl Bumblebee.LogitsProcessor + def init(_logits_processor, _context) do + %{} + end + + @impl Bumblebee.LogitsProcessor + def process(logits_processor, state, logits, context) do + {logits_processor.fun.(logits, context), state} + end + +end diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 2a3b37d8..a2dc47b1 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -107,8 +107,7 @@ defmodule Bumblebee.Text.GenerationTest do assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]])) end - - test "with stateful logits processor with batch size of 1" do + test "with stateful logits processor with different batch sizes" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"}) @@ -144,10 +143,9 @@ defmodule Bumblebee.Text.GenerationTest do generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config, + # ToDo Bumblee.configure() logits_processors: [ - &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, - initial_enforced_token_ids: [79] - ) + Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, initial_enforced_token_ids: [79]) ] ) @@ -159,10 +157,10 @@ defmodule Bumblebee.Text.GenerationTest do %{token_ids: token_ids} = generate.(params, inputs) # first token_id should be 79 as we enforce token_id 79 - assert_equal(token_ids[[0,0]], 79) + assert_equal(token_ids[[0, 0]], 79) # in the next step we increment from 79 to 80 and enforce token_id 80 - assert_equal(token_ids[[0,1]], 80) + assert_equal(token_ids[[0, 1]], 80) ######################################################### # batch size of 2 @@ -175,15 +173,12 @@ defmodule Bumblebee.Text.GenerationTest do # this is the same example as above, but with a batch size of 2. - generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3) generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ - &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2, - initial_enforced_token_ids: [78, 20] - ) + Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, initial_enforced_token_ids: [78, 20]) ] ) @@ -211,27 +206,50 @@ defmodule Bumblebee.Text.GenerationTest do end defmodule StatefulLogitsProcessing do + @moduledoc false + import Nx.Defn - deftransform stateful_processor(logits, context, opts \\ []) do - initial_enforced_token_ids = Enum.map(opts[:initial_enforced_token_ids], &List.wrap(&1)) - initial_enforced_batch_token_id = Nx.tensor(initial_enforced_token_ids) |> Nx.vectorize(:batch) + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.LogitsProcessor + + options = [ + initial_enforced_token_ids: [ + default: [], + doc: "A list of token ids to enforce on the first iteration" + ] + ] - enforced_token_id = - context.logits_processor_state[:next_enforced_token_id] || initial_enforced_batch_token_id + defstruct Bumblebee.Shared.option_defaults(options) - logits = enforce_token(logits, enforced_token_id) + @impl Bumblebee.Configurable + def config(logits_processor, opts) do + Bumblebee.Shared.put_config_attrs(logits_processor, opts) + end + + @impl Bumblebee.LogitsProcessor + def init(logits_processor, _context) do + initial_enforced_token_ids = + Enum.map(logits_processor.initial_enforced_token_ids, &List.wrap(&1)) + + initial_enforced_batch_token_id = + Nx.tensor(initial_enforced_token_ids) + %{ + sfp_state: %{ + next_enforced_token_id: initial_enforced_batch_token_id + } + } + end - next_enforced_token_id = Nx.add(enforced_token_id, 1) + @impl Bumblebee.LogitsProcessor + def process(_logits_processor, state, logits, _context) do + sfp_state = state.sfp_state + logits = enforce_token(logits, sfp_state.next_enforced_token_id) - context = - put_in( - context, - [:logits_processor_state, :next_enforced_token_id], - next_enforced_token_id - ) + sfp_state = %{sfp_state | next_enforced_token_id: Nx.add(sfp_state.next_enforced_token_id, 1)} + state = %{state | sfp_state: sfp_state} - {logits, context} + {logits, state} end defnp enforce_token(logits, token_id) do From dfa223c6661da4eb4cac27e38d1c7ab2af9e5347 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 27 Oct 2025 13:58:59 +0100 Subject: [PATCH 17/36] clean up --- lib/bumblebee/text/generation.ex | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 20a792ce..a7b58112 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -393,11 +393,8 @@ defmodule Bumblebee.Text.Generation do |> Enum.filter(fn processor -> processor != nil end) |> Enum.map(fn processor -> if is_function(processor, 2) do - gna = %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: processor} - if gna.fun == nil do raise "hell" end - gna + %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: processor} else - if processor == nil do raise "heaven" end processor end end) @@ -417,13 +414,6 @@ defmodule Bumblebee.Text.Generation do end {init_fun, process_fun} - - # Technically, the :logits_processors options is public API, but we can make it backward-compatible. For example, we - # can define %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: fun}, where the state is always empty and - # process just invokes the fun. I would even use that for the built-in processors, so that we don't need to define - # a bunch of new modules. - # - # %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: fun} end defnp generate_impl( From 9098bdaa4ac675194db5d4e6cf4c3f9f4918665f Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 27 Oct 2025 13:59:31 +0100 Subject: [PATCH 18/36] mix format --- lib/bumblebee/logits_processor.ex | 3 +- lib/bumblebee/text/generation.ex | 44 ++++++++++++++----- .../generation/stateless_logits_processor.ex | 1 - test/bumblebee/text/generation_test.exs | 15 +++++-- 4 files changed, 47 insertions(+), 16 deletions(-) diff --git a/lib/bumblebee/logits_processor.ex b/lib/bumblebee/logits_processor.ex index 83815dcb..6e9002dd 100644 --- a/lib/bumblebee/logits_processor.ex +++ b/lib/bumblebee/logits_processor.ex @@ -15,7 +15,6 @@ defmodule Bumblebee.LogitsProcessor do @type state :: Nx.Container.t() - @doc """ Initializes state for a new logits processor. @@ -34,6 +33,6 @@ defmodule Bumblebee.LogitsProcessor do t(), state(), logits :: Nx.Tensor.t(), - context:: term() + context :: term() ) :: {logits :: Nx.Tensor.t(), state :: map()} end diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index a7b58112..6156bd46 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -164,7 +164,8 @@ defmodule Bumblebee.Text.Generation do {_init_fun, predict_fun} = Axon.build(model, global_layer_options: global_layer_options) - {logits_processor_init_fun, logits_processor_process_fun} = get_logits_processor(min_length_fun, config, opts[:logits_processors]) + {logits_processor_init_fun, logits_processor_process_fun} = + get_logits_processor(min_length_fun, config, opts[:logits_processors]) &generate_impl( &2, @@ -387,7 +388,6 @@ defmodule Bumblebee.Text.Generation do [] end ++ logits_processors - processors = processors |> Enum.filter(fn processor -> processor != nil end) @@ -399,7 +399,6 @@ defmodule Bumblebee.Text.Generation do end end) - init_fun = fn context -> Enum.reduce(processors, %{}, fn processor, state_acc -> state = Bumblebee.logits_processor_init(processor, context) @@ -521,8 +520,14 @@ defmodule Bumblebee.Text.Generation do pad_token_id = opts[:pad_token_id] eos_token_id = opts[:eos_token_id] - - state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id, logits_processor_init_fun) + state = + init_sequences( + decoder_input_ids, + padded_batch_item?, + max_length, + pad_token_id, + logits_processor_init_fun + ) # The loop works with inputs of length 1, so if the initial input # is longer, we make the initial pass outside @@ -562,7 +567,13 @@ defmodule Bumblebee.Text.Generation do state end - defnp init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id, logits_processor_init_fun) do + defnp init_sequences( + decoder_input_ids, + padded_batch_item?, + max_length, + pad_token_id, + logits_processor_init_fun + ) do {batch_size, length} = Nx.shape(decoder_input_ids) sequences = Nx.broadcast(pad_token_id, {batch_size, max_length}) @@ -577,7 +588,7 @@ defmodule Bumblebee.Text.Generation do context = %{ sequences: sequences, input_length: length, - length: length, + length: length } %{ @@ -668,7 +679,6 @@ defmodule Bumblebee.Text.Generation do end defnp batch_process_logits(logits_processor_process_fun, logits, state) do - logits = Nx.vectorize(logits, :batch) context = %{ @@ -709,7 +719,14 @@ defmodule Bumblebee.Text.Generation do top_k = opts[:top_k] penalty_alpha = opts[:penalty_alpha] - state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id, logits_processor_init_fun) + state = + init_sequences( + decoder_input_ids, + padded_batch_item?, + max_length, + pad_token_id, + logits_processor_init_fun + ) # Step (1) # Initial pass to obtain hidden state and expand inputs to top-k @@ -874,7 +891,14 @@ defmodule Bumblebee.Text.Generation do pad_token_id = opts[:pad_token_id] eos_token_id = opts[:eos_token_id] - state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id, logits_processor_init_fun) + state = + init_sequences( + decoder_input_ids, + padded_batch_item?, + max_length, + pad_token_id, + logits_processor_init_fun + ) prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key() diff --git a/lib/bumblebee/text/generation/stateless_logits_processor.ex b/lib/bumblebee/text/generation/stateless_logits_processor.ex index 16a7053a..fb05a85f 100644 --- a/lib/bumblebee/text/generation/stateless_logits_processor.ex +++ b/lib/bumblebee/text/generation/stateless_logits_processor.ex @@ -27,5 +27,4 @@ defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do def process(logits_processor, state, logits, context) do {logits_processor.fun.(logits, context), state} end - end diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index a2dc47b1..7397d1ec 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -145,7 +145,9 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.Text.Generation.build_generate(model, spec, generation_config, # ToDo Bumblee.configure() logits_processors: [ - Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, initial_enforced_token_ids: [79]) + Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, + initial_enforced_token_ids: [79] + ) ] ) @@ -178,7 +180,9 @@ defmodule Bumblebee.Text.GenerationTest do generate = Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ - Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, initial_enforced_token_ids: [78, 20]) + Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, + initial_enforced_token_ids: [78, 20] + ) ] ) @@ -234,6 +238,7 @@ defmodule Bumblebee.Text.GenerationTest do initial_enforced_batch_token_id = Nx.tensor(initial_enforced_token_ids) + %{ sfp_state: %{ next_enforced_token_id: initial_enforced_batch_token_id @@ -246,7 +251,11 @@ defmodule Bumblebee.Text.GenerationTest do sfp_state = state.sfp_state logits = enforce_token(logits, sfp_state.next_enforced_token_id) - sfp_state = %{sfp_state | next_enforced_token_id: Nx.add(sfp_state.next_enforced_token_id, 1)} + sfp_state = %{ + sfp_state + | next_enforced_token_id: Nx.add(sfp_state.next_enforced_token_id, 1) + } + state = %{state | sfp_state: sfp_state} {logits, state} From 544d80f8e4cbf9db1096f3935b10385f946dad2f Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 27 Oct 2025 14:01:13 +0100 Subject: [PATCH 19/36] vectorized sequences are called sequence in context --- lib/bumblebee/text/generation.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 6156bd46..4733573f 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -682,7 +682,7 @@ defmodule Bumblebee.Text.Generation do logits = Nx.vectorize(logits, :batch) context = %{ - sequences: Nx.vectorize(state.sequences, :batch), + sequence: Nx.vectorize(state.sequences, :batch), length: state.length, input_length: state.input_length } From 2ba5e0adb32eaeda42a957a38048363acc21ea57 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Mon, 27 Oct 2025 11:01:46 +0100 Subject: [PATCH 20/36] don't vectorize all the logits processor state --- lib/bumblebee/text/generation.ex | 12 ++++++------ test/bumblebee/text/generation_test.exs | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 4733573f..655c34cb 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -687,15 +687,15 @@ defmodule Bumblebee.Text.Generation do input_length: state.input_length } - {logits, new_logits_processor_state} = - logits_processor_process_fun.(logits, context, Nx.vectorize(state.logits_processor_state, :batch)) + {logits, logits_processor_state} = + logits_processor_process_fun.( + logits, + context, + state.logits_processor_state + ) logits = Nx.devectorize(logits, keep_names: false) - logits_processor_state = - Nx.devectorize(new_logits_processor_state, keep_names: false) - - {logits, %{state | logits_processor_state: logits_processor_state}} end diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 7397d1ec..61adc93e 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -248,15 +248,15 @@ defmodule Bumblebee.Text.GenerationTest do @impl Bumblebee.LogitsProcessor def process(_logits_processor, state, logits, _context) do - sfp_state = state.sfp_state - logits = enforce_token(logits, sfp_state.next_enforced_token_id) + next_enforced_token_id = Nx.vectorize(state.sfp_state.next_enforced_token_id, :batch) - sfp_state = %{ - sfp_state - | next_enforced_token_id: Nx.add(sfp_state.next_enforced_token_id, 1) - } + logits = enforce_token(logits, next_enforced_token_id) + + next_enforced_token_id = + Nx.add(next_enforced_token_id, 1) + |> Nx.devectorize(keep_names: false) - state = %{state | sfp_state: sfp_state} + state = put_in(state.sfp_state.next_enforced_token_id, next_enforced_token_id) {logits, state} end From 196c8f05e036c3b65f599b5eebf41fd2f0d1a9c3 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 5 Nov 2025 15:54:02 +0100 Subject: [PATCH 21/36] swap {logits, state} to {state, logits} --- lib/bumblebee/logits_processor.ex | 2 +- lib/bumblebee/text/generation.ex | 2 +- lib/bumblebee/text/generation/stateless_logits_processor.ex | 2 +- test/bumblebee/text/generation_test.exs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/bumblebee/logits_processor.ex b/lib/bumblebee/logits_processor.ex index 6e9002dd..b2a8b998 100644 --- a/lib/bumblebee/logits_processor.ex +++ b/lib/bumblebee/logits_processor.ex @@ -34,5 +34,5 @@ defmodule Bumblebee.LogitsProcessor do state(), logits :: Nx.Tensor.t(), context :: term() - ) :: {logits :: Nx.Tensor.t(), state :: map()} + ) :: {state :: map(), logits :: Nx.Tensor.t()} end diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 655c34cb..e6ff1a66 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -687,7 +687,7 @@ defmodule Bumblebee.Text.Generation do input_length: state.input_length } - {logits, logits_processor_state} = + {logits_processor_state, logits} = logits_processor_process_fun.( logits, context, diff --git a/lib/bumblebee/text/generation/stateless_logits_processor.ex b/lib/bumblebee/text/generation/stateless_logits_processor.ex index fb05a85f..e34133fc 100644 --- a/lib/bumblebee/text/generation/stateless_logits_processor.ex +++ b/lib/bumblebee/text/generation/stateless_logits_processor.ex @@ -25,6 +25,6 @@ defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do @impl Bumblebee.LogitsProcessor def process(logits_processor, state, logits, context) do - {logits_processor.fun.(logits, context), state} + {state, logits_processor.fun.(logits, context)} end end diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 61adc93e..5857b4b4 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -258,7 +258,7 @@ defmodule Bumblebee.Text.GenerationTest do state = put_in(state.sfp_state.next_enforced_token_id, next_enforced_token_id) - {logits, state} + {state, logits} end defnp enforce_token(logits, token_id) do From ee2a01e0ac135be9b46c19deff30edfe126f9527 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 5 Nov 2025 15:55:16 +0100 Subject: [PATCH 22/36] rename logits_processor_state to logits_processor_states --- lib/bumblebee/text/generation.ex | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index e6ff1a66..27004556 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -598,7 +598,7 @@ defmodule Bumblebee.Text.Generation do finished_length: finished_length, # The ignored return value that we attach all hooks to ignored: Nx.broadcast(0, {batch_size}), - logits_processor_state: logits_processor_init_fun.(context) + logits_processor_states: logits_processor_init_fun.(context) } end @@ -687,16 +687,16 @@ defmodule Bumblebee.Text.Generation do input_length: state.input_length } - {logits_processor_state, logits} = + {logits_processor_states, logits} = logits_processor_process_fun.( logits, context, - state.logits_processor_state + state.logits_processor_states ) logits = Nx.devectorize(logits, keep_names: false) - {logits, %{state | logits_processor_state: logits_processor_state}} + {logits, %{state | logits_processor_states: logits_processor_states}} end # Contrastive search From 3563ff0ece4b88a8de7e89a48d9e00ea7b9b0631 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 5 Nov 2025 15:56:09 +0100 Subject: [PATCH 23/36] states as tuples --- lib/bumblebee/text/generation.ex | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 27004556..d8807dee 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -400,16 +400,22 @@ defmodule Bumblebee.Text.Generation do end) init_fun = fn context -> - Enum.reduce(processors, %{}, fn processor, state_acc -> - state = Bumblebee.logits_processor_init(processor, context) - Map.merge(state_acc, state) + processors + |> Enum.map(fn processor -> + Bumblebee.logits_processor_init(processor, context) end) + |> List.to_tuple() end - process_fun = fn logits, context, state -> - Enum.reduce(processors, {logits, state}, fn processor, {logits, state} -> - Bumblebee.logits_processor_process(processor, state, logits, context) - end) + process_fun = fn logits, context, processor_states -> + {processor_states, logits} = + processors + |> Enum.zip(Tuple.to_list(processor_states)) + |> Enum.map_reduce(logits, fn {processor, processor_state}, logits -> + Bumblebee.logits_processor_process(processor, processor_state, logits, context) + end) + + {List.to_tuple(processor_states), logits} end {init_fun, process_fun} From 6db771ee197090048fed1311036deb2281f2fbff Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 5 Nov 2025 15:56:14 +0100 Subject: [PATCH 24/36] update test --- test/bumblebee/text/generation_test.exs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 5857b4b4..1ad8a124 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -240,15 +240,13 @@ defmodule Bumblebee.Text.GenerationTest do Nx.tensor(initial_enforced_token_ids) %{ - sfp_state: %{ - next_enforced_token_id: initial_enforced_batch_token_id - } + next_enforced_token_id: initial_enforced_batch_token_id } end @impl Bumblebee.LogitsProcessor def process(_logits_processor, state, logits, _context) do - next_enforced_token_id = Nx.vectorize(state.sfp_state.next_enforced_token_id, :batch) + next_enforced_token_id = Nx.vectorize(state.next_enforced_token_id, :batch) logits = enforce_token(logits, next_enforced_token_id) @@ -256,7 +254,7 @@ defmodule Bumblebee.Text.GenerationTest do Nx.add(next_enforced_token_id, 1) |> Nx.devectorize(keep_names: false) - state = put_in(state.sfp_state.next_enforced_token_id, next_enforced_token_id) + state = put_in(state.next_enforced_token_id, next_enforced_token_id) {state, logits} end From c8442e09e8a0cbaf503bc93e2294fc1ecf556f28 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 5 Nov 2025 16:58:10 +0100 Subject: [PATCH 25/36] single initial state for all batch entries --- test/bumblebee/text/generation_test.exs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 1ad8a124..09f7f1b7 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -146,7 +146,7 @@ defmodule Bumblebee.Text.GenerationTest do # ToDo Bumblee.configure() logits_processors: [ Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, - initial_enforced_token_ids: [79] + initial_enforced_token_id: 79 ) ] ) @@ -181,7 +181,7 @@ defmodule Bumblebee.Text.GenerationTest do Bumblebee.Text.Generation.build_generate(model, spec, generation_config, logits_processors: [ Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing, - initial_enforced_token_ids: [78, 20] + initial_enforced_token_id: 78 ) ] ) @@ -203,10 +203,10 @@ defmodule Bumblebee.Text.GenerationTest do # second entry in batch # first token_id is 20 as we enforce token_id 20 on the first iteration - assert_equal(token_ids[[1, 0]], 20) + assert_equal(token_ids[[1, 0]], 78) # in the next step we increment from 20 to 21 and enforce token_id 21 - assert_equal(token_ids[[1, 1]], 21) + assert_equal(token_ids[[1, 1]], 79) end defmodule StatefulLogitsProcessing do @@ -218,9 +218,9 @@ defmodule Bumblebee.Text.GenerationTest do @behaviour Bumblebee.LogitsProcessor options = [ - initial_enforced_token_ids: [ + initial_enforced_token_id: [ default: [], - doc: "A list of token ids to enforce on the first iteration" + doc: "A token id to enforce on the first iteration" ] ] @@ -232,12 +232,11 @@ defmodule Bumblebee.Text.GenerationTest do end @impl Bumblebee.LogitsProcessor - def init(logits_processor, _context) do - initial_enforced_token_ids = - Enum.map(logits_processor.initial_enforced_token_ids, &List.wrap(&1)) + def init(logits_processor, context) do + batch_size = Nx.axis_size(context.sequences, 0) initial_enforced_batch_token_id = - Nx.tensor(initial_enforced_token_ids) + Nx.broadcast(logits_processor.initial_enforced_token_id, {batch_size, 1}) %{ next_enforced_token_id: initial_enforced_batch_token_id From 41dd2ad4dc4cb7d95b7c9e9f23e24f05e0ff2337 Mon Sep 17 00:00:00 2001 From: Joel Koch Date: Wed, 5 Nov 2025 17:28:31 +0100 Subject: [PATCH 26/36] vectorize sequence for init, derive vectorized state --- lib/bumblebee/text/generation.ex | 2 +- test/bumblebee/text/generation_test.exs | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index d8807dee..669c1b7e 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -592,7 +592,7 @@ defmodule Bumblebee.Text.Generation do finished_length = Nx.select(padded_batch_item?, 1, 0) context = %{ - sequences: sequences, + sequence: Nx.vectorize(sequences, :batch), input_length: length, length: length } diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 09f7f1b7..3eff4dba 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -233,10 +233,10 @@ defmodule Bumblebee.Text.GenerationTest do @impl Bumblebee.LogitsProcessor def init(logits_processor, context) do - batch_size = Nx.axis_size(context.sequences, 0) + initial_enforced_token_id = Nx.tensor([logits_processor.initial_enforced_token_id]) - initial_enforced_batch_token_id = - Nx.broadcast(logits_processor.initial_enforced_token_id, {batch_size, 1}) + [initial_enforced_batch_token_id, _sequence] = + Nx.broadcast_vectors([initial_enforced_token_id, context.sequence]) %{ next_enforced_token_id: initial_enforced_batch_token_id @@ -245,13 +245,11 @@ defmodule Bumblebee.Text.GenerationTest do @impl Bumblebee.LogitsProcessor def process(_logits_processor, state, logits, _context) do - next_enforced_token_id = Nx.vectorize(state.next_enforced_token_id, :batch) + next_enforced_token_id = state.next_enforced_token_id logits = enforce_token(logits, next_enforced_token_id) - next_enforced_token_id = - Nx.add(next_enforced_token_id, 1) - |> Nx.devectorize(keep_names: false) + next_enforced_token_id = Nx.add(next_enforced_token_id, 1) state = put_in(state.next_enforced_token_id, next_enforced_token_id) From 311f77cc77dbc7136f067ee8822014f4d1c10a42 Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 14 Nov 2025 23:26:59 +0100 Subject: [PATCH 27/36] switch to EXLA as the evaluator is lacking vectorisation support: ``` ** (RuntimeError) unexpected vectorized axes in evaluator for operation :add: #Nx.Tensor< vectorized[batch: 1] s32[1] Nx.Defn.Expr tensor a s32[1] b = reshape a s32[1][1] ``` --- test/bumblebee/text/generation_test.exs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 3eff4dba..571f8487 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -156,7 +156,8 @@ defmodule Bumblebee.Text.GenerationTest do # # Now, with the processor below, we expect the sequence of [79, 80, 81 ..] - %{token_ids: token_ids} = generate.(params, inputs) + %{token_ids: token_ids} = + Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA) # first token_id should be 79 as we enforce token_id 79 assert_equal(token_ids[[0, 0]], 79) @@ -186,7 +187,8 @@ defmodule Bumblebee.Text.GenerationTest do ] ) - %{token_ids: token_ids} = generate.(params, inputs) + %{token_ids: token_ids} = + Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA) # result without logit processor: 80, 80, 80 From ec922642c6b58d205a6fa23090e3609225cea9f5 Mon Sep 17 00:00:00 2001 From: chris Date: Fri, 14 Nov 2025 23:30:32 +0100 Subject: [PATCH 28/36] Apply suggestion from @jonatanklosko MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jonatan Kłosko --- test/bumblebee/text/generation_test.exs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 571f8487..a07bf53a 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -193,22 +193,14 @@ defmodule Bumblebee.Text.GenerationTest do # result without logit processor: 80, 80, 80 # first entry in batch - # first token_id should be 78 as we enforce token_id 78 on the first - # iteration assert_equal(token_ids[[0, 0]], 78) - - # second should be 79 as we increment the enforced token_id from 78 to 79 assert_equal(token_ids[[0, 1]], 79) - - # in the next step we increment from 79 to 80 and enforce token_id 80 assert_equal(token_ids[[0, 2]], 80) # second entry in batch - # first token_id is 20 as we enforce token_id 20 on the first iteration assert_equal(token_ids[[1, 0]], 78) - - # in the next step we increment from 20 to 21 and enforce token_id 21 assert_equal(token_ids[[1, 1]], 79) + assert_equal(token_ids[[1, 2]], 80) end defmodule StatefulLogitsProcessing do From 201e103085187057bec9bfc26006b69d99b945f3 Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 14 Nov 2025 23:36:14 +0100 Subject: [PATCH 29/36] removed comments --- test/bumblebee/text/generation_test.exs | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index a07bf53a..c8f8b374 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -159,10 +159,7 @@ defmodule Bumblebee.Text.GenerationTest do %{token_ids: token_ids} = Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA) - # first token_id should be 79 as we enforce token_id 79 assert_equal(token_ids[[0, 0]], 79) - - # in the next step we increment from 79 to 80 and enforce token_id 80 assert_equal(token_ids[[0, 1]], 80) ######################################################### From 70d7f655c530373a8690ffe9bf2bde32984a9f4b Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 14 Nov 2025 23:39:33 +0100 Subject: [PATCH 30/36] slimmed down comments more --- test/bumblebee/text/generation_test.exs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index c8f8b374..39866def 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -134,10 +134,6 @@ defmodule Bumblebee.Text.GenerationTest do # given initial ID, then increments the token ID to be enforced on the # following iterations. The ID of the token to be enforced is passed on # between iterations using the logits_processor_state. - # - # So invoked with the initial ID of 79, it enforces 79, 80, 81, ... in - # the subsequent iterations, demonstrating the use of the state in a - # logits processor. generation_config = Bumblebee.configure(generation_config, max_new_tokens: 2) @@ -154,7 +150,8 @@ defmodule Bumblebee.Text.GenerationTest do # The result without the logits processor would be, as with the first # decoder test above, [80, 80, 80]. # - # Now, with the processor below, we expect the sequence of [79, 80, 81 ..] + # Now, with the processor below, we expect the sequence of [79, 80, 81 ..], + # demonstrating the use of the state in a logits processor. %{token_ids: token_ids} = Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA) From ce92584980fbd11a5e35be7abd5d14e5d927b090 Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 14 Nov 2025 23:51:33 +0100 Subject: [PATCH 31/36] introduced types for init_context and process_context --- lib/bumblebee.ex | 4 ++-- lib/bumblebee/logits_processor.ex | 12 ++++++++++-- .../text/generation/stateless_logits_processor.ex | 6 +++--- test/bumblebee/text/generation_test.exs | 6 +++--- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 06ae8d09..62010053 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -1092,7 +1092,7 @@ defmodule Bumblebee do @doc type: :logits_processor @spec logits_processor_init( Bumblebee.LogitsProcessor.t(), - context :: term() + context :: Bumblebee.LogitsProcessor.init_context() ) :: Bumblebee.LogitsProcessor.state() def logits_processor_init(%module{} = logits_processor, context) do module.init(logits_processor, context) @@ -1107,7 +1107,7 @@ defmodule Bumblebee do Bumblebee.LogitsProcessor.t(), Bumblebee.LogitsProcessor.state(), logits :: Nx.Tensor.t(), - context :: term() + context :: Bumblebee.LogitsProcessor.process_context() ) :: {Bumblebee.LogitsProcessor.state(), logits :: Nx.Tensor.t()} def logits_processor_process(%module{} = logits_processor, state, logits, context) do module.process(logits_processor, state, logits, context) diff --git a/lib/bumblebee/logits_processor.ex b/lib/bumblebee/logits_processor.ex index b2a8b998..40c74daf 100644 --- a/lib/bumblebee/logits_processor.ex +++ b/lib/bumblebee/logits_processor.ex @@ -15,6 +15,14 @@ defmodule Bumblebee.LogitsProcessor do @type state :: Nx.Container.t() + @type process_context :: %{ + sequence: Nx.Tensor.t(), + length: Nx.Tensor.t(), + input_length: Nx.Tensor.t() + } + + @type init_context :: %{} + @doc """ Initializes state for a new logits processor. @@ -24,7 +32,7 @@ defmodule Bumblebee.LogitsProcessor do Oftentimes logits processors are stateless, in which case this function can return an empty container, such as `{}`. """ - @callback init(t(), any()) :: state() + @callback init(t(), init_context()) :: state() @doc """ Processes logits, applying specific rules. @@ -33,6 +41,6 @@ defmodule Bumblebee.LogitsProcessor do t(), state(), logits :: Nx.Tensor.t(), - context :: term() + context :: process_context() ) :: {state :: map(), logits :: Nx.Tensor.t()} end diff --git a/lib/bumblebee/text/generation/stateless_logits_processor.ex b/lib/bumblebee/text/generation/stateless_logits_processor.ex index e34133fc..8e84d6fd 100644 --- a/lib/bumblebee/text/generation/stateless_logits_processor.ex +++ b/lib/bumblebee/text/generation/stateless_logits_processor.ex @@ -19,12 +19,12 @@ defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do end @impl Bumblebee.LogitsProcessor - def init(_logits_processor, _context) do + def init(_logits_processor, _init_context) do %{} end @impl Bumblebee.LogitsProcessor - def process(logits_processor, state, logits, context) do - {state, logits_processor.fun.(logits, context)} + def process(logits_processor, state, logits, process_context) do + {state, logits_processor.fun.(logits, process_context)} end end diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 39866def..673e6ad1 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -220,11 +220,11 @@ defmodule Bumblebee.Text.GenerationTest do end @impl Bumblebee.LogitsProcessor - def init(logits_processor, context) do + def init(logits_processor, init_context) do initial_enforced_token_id = Nx.tensor([logits_processor.initial_enforced_token_id]) [initial_enforced_batch_token_id, _sequence] = - Nx.broadcast_vectors([initial_enforced_token_id, context.sequence]) + Nx.broadcast_vectors([initial_enforced_token_id, init_context.sequence]) %{ next_enforced_token_id: initial_enforced_batch_token_id @@ -232,7 +232,7 @@ defmodule Bumblebee.Text.GenerationTest do end @impl Bumblebee.LogitsProcessor - def process(_logits_processor, state, logits, _context) do + def process(_logits_processor, state, logits, _process_context) do next_enforced_token_id = state.next_enforced_token_id logits = enforce_token(logits, next_enforced_token_id) From 6d8f494c3b40e3b4dd68e342c23b55ec16609380 Mon Sep 17 00:00:00 2001 From: Chris Date: Sat, 15 Nov 2025 00:10:08 +0100 Subject: [PATCH 32/36] don't vectorize initial_enforced_token_id in test as it's the same over all batches now --- test/bumblebee/text/generation_test.exs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs index 673e6ad1..10014b4b 100644 --- a/test/bumblebee/text/generation_test.exs +++ b/test/bumblebee/text/generation_test.exs @@ -220,14 +220,11 @@ defmodule Bumblebee.Text.GenerationTest do end @impl Bumblebee.LogitsProcessor - def init(logits_processor, init_context) do + def init(logits_processor, _init_context) do initial_enforced_token_id = Nx.tensor([logits_processor.initial_enforced_token_id]) - [initial_enforced_batch_token_id, _sequence] = - Nx.broadcast_vectors([initial_enforced_token_id, init_context.sequence]) - %{ - next_enforced_token_id: initial_enforced_batch_token_id + next_enforced_token_id: initial_enforced_token_id } end From 578ce114aa046366b0633f31862aefcd07c67bcf Mon Sep 17 00:00:00 2001 From: Chris Date: Sat, 15 Nov 2025 00:41:52 +0100 Subject: [PATCH 33/36] bonus track: two more livebooks concerning logits processing. Not strictly related to statefull processing --- notebooks/debug_print.livemd | 203 +++++++++++++++++++++++++++++++++ notebooks/suppressing_e.livemd | 160 ++++++++++++++++++++++++++ 2 files changed, 363 insertions(+) create mode 100644 notebooks/debug_print.livemd create mode 100644 notebooks/suppressing_e.livemd diff --git a/notebooks/debug_print.livemd b/notebooks/debug_print.livemd new file mode 100644 index 00000000..227ec7e3 --- /dev/null +++ b/notebooks/debug_print.livemd @@ -0,0 +1,203 @@ +# Printing top logits and token ids + +```elixir +Mix.install([ + {:bumblebee, "~> 0.6"}, + {:nx, "~> 0.10.0", override: true}, + {:exla, "~> 0.10.0"}, + {:emlx, github: "elixir-nx/emlx"} +]) + +# backend = {EMLX.Backend, device: :gpu} +# compiler = EMLX +backend = {EXLA.Backend, client: :host} +compiler = EXLA + +Nx.global_default_backend(backend) +``` + +## A print logits processor + +```elixir +defmodule PrintLogitsProcessor do + import Nx.Defn + + deftransform debug_processor(logits, context, opts \\ []) do + k = opts[:debug_limit] + + print_top_k_logits_and_token_ids(logits, k, context.sequence) + end + + defnp print_top_k_logits_and_token_ids(logits, k, sequence) do + token = create_token() + + {top_values, top_indices} = Nx.top_k(logits, k: k) + + {token, _sequence} = + hook_token(token, sequence, :sequence, &IO.inspect({:sequence, &1}, limit: :infinity)) + + {token, _top_values} = + hook_token(token, top_values, :top_values, &IO.inspect({:logits, &1})) + + {token, _top_indices} = + hook_token(token, top_indices, :top_indices, &IO.inspect({:token_ids, &1})) + + attach_token(token, logits) + end +end +``` + +## Building the generate function + +```elixir +repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"} + +sequence_length = 512 + +max_new_tokens = 32 + +prompt = """ +<|im_start|>system +You are a helpful AI assistant named SmolLM. You tell phantastic poems about airships.<|im_end|> +<|im_start|>user +Tell about airships.<|im_end|> +<|im_start|>assistant +""" + +{:ok, model_info} = Bumblebee.load_model(repo, backend: backend) + +{:ok, tokenizer} = Bumblebee.load_tokenizer(repo) +{:ok, generation_config} = Bumblebee.load_generation_config(repo) + +generation_config = + Bumblebee.configure(generation_config, + max_new_tokens: max_new_tokens, + strategy: %{type: :multinomial_sampling, top_k: 3 } + ) + +%{model: model, params: params, spec: spec} = model_info + +generate_fun = + Bumblebee.Text.Generation.build_generate(model, spec, generation_config, + logits_processors: [ &PrintLogitsProcessor.debug_processor(&1, &2, [debug_limit: 2])] + ) +``` + +## Setting up the serving + +This is taken from `lib/bumblebee/text/text_generation.ex`. It's basically the `Bumblebee.Text.generation` function that you usually use to create text generation servings (with some minor modifications to simplify it). +We must use the lower level API here to be able to include `PrintLogitsProcessor` in `generate_fun`. + +```elixir +alias Bumblebee.Shared + +batch_keys = Shared.sequence_batch_keys(sequence_length) +batch_size = 1 +defn_options = [compiler: compiler] + +preallocate_params = false + +tokenizer = + Bumblebee.configure(tokenizer, + length: sequence_length, + pad_direction: :left, + return_token_type_ids: false, + return_length: true + ) + +validate_input = fn text -> {:ok, %{text: text, seed: :erlang.system_time()}} end + +serving = + Nx.Serving.new( + fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + + scope = {:generate, batch_key} + + generate_fun = + Shared.compile_or_jit(generate_fun, scope, defn_options, true, fn -> + {:sequence_length, sequence_length} = batch_key + + inputs = %{ + "input_ids" => Nx.template({batch_size, sequence_length}, :u32), + "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), + "seed" => Nx.template({batch_size}, :s64) + } + + [params, inputs] + end) + + fn inputs -> + inputs = Shared.maybe_pad(inputs, batch_size) + generate_fun.(params, inputs) |> Shared.serving_post_computation() + end + end, + defn_options + ) + |> Nx.Serving.batch_size(batch_size) + |> Nx.Serving.process_options(batch_keys: batch_keys) + |> Nx.Serving.client_preprocessing(fn input -> + {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input.(&1)) + + texts = Enum.map(inputs, & &1.text) + seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(type: :s64, backend: Nx.BinaryBackend) + + inputs = + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, texts) + end) + + {input_length, inputs} = Map.pop!(inputs, "length") + input_padded_length = Nx.axis_size(inputs["input_ids"], 1) + + inputs = Map.put(inputs, "seed", seed) + + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, {multi?, input_length, input_padded_length}} + end) + |> Nx.Serving.client_postprocessing(fn {%{token_ids: token_ids, length: length}, _metadata}, + {multi?, input_length, input_padded_length} -> + decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids) + output_length = Nx.to_flat_list(length) + input_length = Nx.to_flat_list(input_length) + + Enum.zip_with( + [decoded, output_length, input_length], + fn [decoded, output_length, input_length] -> + token_summary = + %{ + input: input_length, + output: output_length, + padding: input_padded_length - input_length + } + + %{results: [%{text: decoded, token_summary: token_summary}]} + end + ) + |> Shared.normalize_output(multi?) + end) +``` + +## Run the serving + +```elixir +prompt = """ +Tell me about airships. +""" +``` + +### Note: + +In the following cell, the **content of :sequence is padded** ([2, 2, ...] scroll to the right to see the content emerge): + +``` + [2, 2, 2, .. (a lot of 2's later) ... 31530, 549, 563, 1512, 27322, 30, 198, ...] +``` + +Have a look at the [tokenizer.json file on hugging face](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct/blob/main/tokenizer.json) to see the meaning on the tokens. + +```elixir +Nx.Serving.run(serving, prompt) +``` diff --git a/notebooks/suppressing_e.livemd b/notebooks/suppressing_e.livemd new file mode 100644 index 00000000..36500212 --- /dev/null +++ b/notebooks/suppressing_e.livemd @@ -0,0 +1,160 @@ +# Suppressing e + +```elixir +Mix.install([ + {:bumblebee, "~> 0.6.0"}, + {:nx, "~> 0.10.0"}, + {:exla, "~> 0.10.0"}, + {:kino, "~> 0.17.0"}, + {:emlx, "~> 0.2.0"} +]) + +# EMLX is fast but seems to work only with greedy strategy +# backend = {EMLX.Backend, device: :gpu} +# compiler = EMLX + +backend = {EXLA.Backend, client: :host} +compiler = EXLA + +Nx.global_default_backend(backend) +``` + +## Introduction + +In this notebook we outline the general setup for running a Large Language Model (LLM). + +## SmolLM2 + +In this section we look at running the [SmolLM2](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) model from huggingface as it is a small and open source LLM. + + + +Let's load the model and create a serving for text generation: + +```elixir +repo = {:hf, "HuggingFaceTB/SmolLM2-1.7B-Instruct"} + +{:ok, model_info} = Bumblebee.load_model(repo, type: :bf16, backend: backend) +{:ok, tokenizer} = Bumblebee.load_tokenizer(repo) +{:ok, generation_config} = Bumblebee.load_generation_config(repo) + +:ok +``` + +```elixir +generation_config = + Bumblebee.configure(generation_config, + max_new_tokens: 60, + # note that multinomial sampling might still pick one of the suppressed tokens + # depending on top_p or top_k + # strategy: %{type: :greedy_search} + strategy: %{type: :multinomial_sampling, top_k: 10} + # strategy: %{type: :multinomial_sampling, top_p: 0.7} + ) + +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 256], + stream: false, + defn_options: [compiler: compiler] + ) +``` + +```elixir +prompt = """ +<|im_start|>system +You are an AI Shakespeare writing poems. You are not allowed to use the letter e. +Kindoms will fall if you do. +Do NOT use the letter e. +If you use the letter e it will have catastrophic consequences!<|im_end|> +<|im_start|>user +Write a poem praising the functional programming concept.<|im_end|> +<|im_start|>assistant +""" + +Kino.Text.new(prompt) +``` + +```elixir +%{results: [%{text: out}]} = Nx.Serving.run(serving, prompt) + +Kino.Text.new(out) +``` + +```elixir +String.graphemes(out) |> Enum.count(&(&1 == "e" or &1 == "E")) +``` + +## Constrained Sampling + +First, we find all tokens in the vocabulary of our tokenizer which contain the letter `e`. + +```elixir +alias Bumblebee.Tokenizer + +last_token_id = model_info.spec.vocab_size - 1 + +special_tokens_ids = + Tokenizer.all_special_tokens(tokenizer) |> Enum.map(&Tokenizer.token_to_id(tokenizer, &1)) + +allowed_tokens_ids = + special_tokens_ids ++ Enum.map([""], &Tokenizer.token_to_id(tokenizer, &1)) + +token_ids_with_e = + for token_id <- 17..last_token_id, + token_id not in allowed_tokens_ids, + token = Tokenizer.id_to_token(tokenizer, token_id), + String.contains?(token, "e") or String.contains?(token, "E") do + token_id + end + +Enum.map(token_ids_with_e, fn id -> {id, Tokenizer.id_to_token(tokenizer, id)} end) +``` + +Then, we suppress all token ids that correspond to a token containing the letter `e` during generation. + +This is the logits processor used when we pass the config as below. + + + +```elixir + deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do + opts = Keyword.validate!(opts, [:suppressed_token_ids]) + + indices = opts[:suppressed_token_ids] |> Nx.tensor() |> Nx.new_axis(-1) + values = Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {Nx.size(indices)}) + Nx.indexed_put(logits, indices, values) + end +``` + +```elixir +generation_config = + Bumblebee.configure(generation_config, + suppressed_token_ids: token_ids_with_e + ) + +serving = + Bumblebee.Text.generation(model_info, tokenizer, generation_config, + compile: [batch_size: 1, sequence_length: 1024], + stream: false, + defn_options: [compiler: compiler] + ) + +%{results: [%{text: out}]} = Nx.Serving.run(serving, prompt) + +Kino.Text.new(out) +``` + +```elixir +String.contains?(out, "e") or String.contains?(out, "E") +``` + +```elixir +%{"input_ids" => out_token_ids} = Bumblebee.apply_tokenizer(tokenizer, out) + +out_token_ids = Nx.to_flat_list(out_token_ids) + +ids = Enum.filter(out_token_ids, &(&1 in token_ids_with_e)) + +Enum.map(ids, &Tokenizer.id_to_token(tokenizer, &1)) +``` From 1f7798f2b499a3f97a8ae71963baafe9a2d08d6b Mon Sep 17 00:00:00 2001 From: chris Date: Mon, 17 Nov 2025 18:28:00 +0100 Subject: [PATCH 34/36] Update test/bumblebee/text/generation/logits_processing_test.exs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jonatan Kłosko --- test/bumblebee/text/generation/logits_processing_test.exs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/bumblebee/text/generation/logits_processing_test.exs b/test/bumblebee/text/generation/logits_processing_test.exs index 190e97c4..5bc5a44f 100644 --- a/test/bumblebee/text/generation/logits_processing_test.exs +++ b/test/bumblebee/text/generation/logits_processing_test.exs @@ -382,8 +382,7 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do %{ sequence: Nx.tensor(sequence), length: Enum.count(sequence, &(&1 != 0)), - input_length: 1, - logits_processor_state: %{} + input_length: 1 } end end From e9d0c780e835657daa129c6eb421e7bcf74060bd Mon Sep 17 00:00:00 2001 From: Chris Date: Mon, 17 Nov 2025 18:34:37 +0100 Subject: [PATCH 35/36] moving livebooks to separate PR --- notebooks/debug_print.livemd | 203 --------------------------------- notebooks/suppressing_e.livemd | 160 -------------------------- 2 files changed, 363 deletions(-) delete mode 100644 notebooks/debug_print.livemd delete mode 100644 notebooks/suppressing_e.livemd diff --git a/notebooks/debug_print.livemd b/notebooks/debug_print.livemd deleted file mode 100644 index 227ec7e3..00000000 --- a/notebooks/debug_print.livemd +++ /dev/null @@ -1,203 +0,0 @@ -# Printing top logits and token ids - -```elixir -Mix.install([ - {:bumblebee, "~> 0.6"}, - {:nx, "~> 0.10.0", override: true}, - {:exla, "~> 0.10.0"}, - {:emlx, github: "elixir-nx/emlx"} -]) - -# backend = {EMLX.Backend, device: :gpu} -# compiler = EMLX -backend = {EXLA.Backend, client: :host} -compiler = EXLA - -Nx.global_default_backend(backend) -``` - -## A print logits processor - -```elixir -defmodule PrintLogitsProcessor do - import Nx.Defn - - deftransform debug_processor(logits, context, opts \\ []) do - k = opts[:debug_limit] - - print_top_k_logits_and_token_ids(logits, k, context.sequence) - end - - defnp print_top_k_logits_and_token_ids(logits, k, sequence) do - token = create_token() - - {top_values, top_indices} = Nx.top_k(logits, k: k) - - {token, _sequence} = - hook_token(token, sequence, :sequence, &IO.inspect({:sequence, &1}, limit: :infinity)) - - {token, _top_values} = - hook_token(token, top_values, :top_values, &IO.inspect({:logits, &1})) - - {token, _top_indices} = - hook_token(token, top_indices, :top_indices, &IO.inspect({:token_ids, &1})) - - attach_token(token, logits) - end -end -``` - -## Building the generate function - -```elixir -repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"} - -sequence_length = 512 - -max_new_tokens = 32 - -prompt = """ -<|im_start|>system -You are a helpful AI assistant named SmolLM. You tell phantastic poems about airships.<|im_end|> -<|im_start|>user -Tell about airships.<|im_end|> -<|im_start|>assistant -""" - -{:ok, model_info} = Bumblebee.load_model(repo, backend: backend) - -{:ok, tokenizer} = Bumblebee.load_tokenizer(repo) -{:ok, generation_config} = Bumblebee.load_generation_config(repo) - -generation_config = - Bumblebee.configure(generation_config, - max_new_tokens: max_new_tokens, - strategy: %{type: :multinomial_sampling, top_k: 3 } - ) - -%{model: model, params: params, spec: spec} = model_info - -generate_fun = - Bumblebee.Text.Generation.build_generate(model, spec, generation_config, - logits_processors: [ &PrintLogitsProcessor.debug_processor(&1, &2, [debug_limit: 2])] - ) -``` - -## Setting up the serving - -This is taken from `lib/bumblebee/text/text_generation.ex`. It's basically the `Bumblebee.Text.generation` function that you usually use to create text generation servings (with some minor modifications to simplify it). -We must use the lower level API here to be able to include `PrintLogitsProcessor` in `generate_fun`. - -```elixir -alias Bumblebee.Shared - -batch_keys = Shared.sequence_batch_keys(sequence_length) -batch_size = 1 -defn_options = [compiler: compiler] - -preallocate_params = false - -tokenizer = - Bumblebee.configure(tokenizer, - length: sequence_length, - pad_direction: :left, - return_token_type_ids: false, - return_length: true - ) - -validate_input = fn text -> {:ok, %{text: text, seed: :erlang.system_time()}} end - -serving = - Nx.Serving.new( - fn batch_key, defn_options -> - params = Shared.maybe_preallocate(params, preallocate_params, defn_options) - - scope = {:generate, batch_key} - - generate_fun = - Shared.compile_or_jit(generate_fun, scope, defn_options, true, fn -> - {:sequence_length, sequence_length} = batch_key - - inputs = %{ - "input_ids" => Nx.template({batch_size, sequence_length}, :u32), - "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), - "seed" => Nx.template({batch_size}, :s64) - } - - [params, inputs] - end) - - fn inputs -> - inputs = Shared.maybe_pad(inputs, batch_size) - generate_fun.(params, inputs) |> Shared.serving_post_computation() - end - end, - defn_options - ) - |> Nx.Serving.batch_size(batch_size) - |> Nx.Serving.process_options(batch_keys: batch_keys) - |> Nx.Serving.client_preprocessing(fn input -> - {inputs, multi?} = Shared.validate_serving_input!(input, &validate_input.(&1)) - - texts = Enum.map(inputs, & &1.text) - seed = Enum.map(inputs, & &1.seed) |> Nx.tensor(type: :s64, backend: Nx.BinaryBackend) - - inputs = - Nx.with_default_backend(Nx.BinaryBackend, fn -> - Bumblebee.apply_tokenizer(tokenizer, texts) - end) - - {input_length, inputs} = Map.pop!(inputs, "length") - input_padded_length = Nx.axis_size(inputs["input_ids"], 1) - - inputs = Map.put(inputs, "seed", seed) - - batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) - batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) - - {batch, {multi?, input_length, input_padded_length}} - end) - |> Nx.Serving.client_postprocessing(fn {%{token_ids: token_ids, length: length}, _metadata}, - {multi?, input_length, input_padded_length} -> - decoded = Bumblebee.Tokenizer.decode(tokenizer, token_ids) - output_length = Nx.to_flat_list(length) - input_length = Nx.to_flat_list(input_length) - - Enum.zip_with( - [decoded, output_length, input_length], - fn [decoded, output_length, input_length] -> - token_summary = - %{ - input: input_length, - output: output_length, - padding: input_padded_length - input_length - } - - %{results: [%{text: decoded, token_summary: token_summary}]} - end - ) - |> Shared.normalize_output(multi?) - end) -``` - -## Run the serving - -```elixir -prompt = """ -Tell me about airships. -""" -``` - -### Note: - -In the following cell, the **content of :sequence is padded** ([2, 2, ...] scroll to the right to see the content emerge): - -``` - [2, 2, 2, .. (a lot of 2's later) ... 31530, 549, 563, 1512, 27322, 30, 198, ...] -``` - -Have a look at the [tokenizer.json file on hugging face](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct/blob/main/tokenizer.json) to see the meaning on the tokens. - -```elixir -Nx.Serving.run(serving, prompt) -``` diff --git a/notebooks/suppressing_e.livemd b/notebooks/suppressing_e.livemd deleted file mode 100644 index 36500212..00000000 --- a/notebooks/suppressing_e.livemd +++ /dev/null @@ -1,160 +0,0 @@ -# Suppressing e - -```elixir -Mix.install([ - {:bumblebee, "~> 0.6.0"}, - {:nx, "~> 0.10.0"}, - {:exla, "~> 0.10.0"}, - {:kino, "~> 0.17.0"}, - {:emlx, "~> 0.2.0"} -]) - -# EMLX is fast but seems to work only with greedy strategy -# backend = {EMLX.Backend, device: :gpu} -# compiler = EMLX - -backend = {EXLA.Backend, client: :host} -compiler = EXLA - -Nx.global_default_backend(backend) -``` - -## Introduction - -In this notebook we outline the general setup for running a Large Language Model (LLM). - -## SmolLM2 - -In this section we look at running the [SmolLM2](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct) model from huggingface as it is a small and open source LLM. - - - -Let's load the model and create a serving for text generation: - -```elixir -repo = {:hf, "HuggingFaceTB/SmolLM2-1.7B-Instruct"} - -{:ok, model_info} = Bumblebee.load_model(repo, type: :bf16, backend: backend) -{:ok, tokenizer} = Bumblebee.load_tokenizer(repo) -{:ok, generation_config} = Bumblebee.load_generation_config(repo) - -:ok -``` - -```elixir -generation_config = - Bumblebee.configure(generation_config, - max_new_tokens: 60, - # note that multinomial sampling might still pick one of the suppressed tokens - # depending on top_p or top_k - # strategy: %{type: :greedy_search} - strategy: %{type: :multinomial_sampling, top_k: 10} - # strategy: %{type: :multinomial_sampling, top_p: 0.7} - ) - -serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, - compile: [batch_size: 1, sequence_length: 256], - stream: false, - defn_options: [compiler: compiler] - ) -``` - -```elixir -prompt = """ -<|im_start|>system -You are an AI Shakespeare writing poems. You are not allowed to use the letter e. -Kindoms will fall if you do. -Do NOT use the letter e. -If you use the letter e it will have catastrophic consequences!<|im_end|> -<|im_start|>user -Write a poem praising the functional programming concept.<|im_end|> -<|im_start|>assistant -""" - -Kino.Text.new(prompt) -``` - -```elixir -%{results: [%{text: out}]} = Nx.Serving.run(serving, prompt) - -Kino.Text.new(out) -``` - -```elixir -String.graphemes(out) |> Enum.count(&(&1 == "e" or &1 == "E")) -``` - -## Constrained Sampling - -First, we find all tokens in the vocabulary of our tokenizer which contain the letter `e`. - -```elixir -alias Bumblebee.Tokenizer - -last_token_id = model_info.spec.vocab_size - 1 - -special_tokens_ids = - Tokenizer.all_special_tokens(tokenizer) |> Enum.map(&Tokenizer.token_to_id(tokenizer, &1)) - -allowed_tokens_ids = - special_tokens_ids ++ Enum.map([""], &Tokenizer.token_to_id(tokenizer, &1)) - -token_ids_with_e = - for token_id <- 17..last_token_id, - token_id not in allowed_tokens_ids, - token = Tokenizer.id_to_token(tokenizer, token_id), - String.contains?(token, "e") or String.contains?(token, "E") do - token_id - end - -Enum.map(token_ids_with_e, fn id -> {id, Tokenizer.id_to_token(tokenizer, id)} end) -``` - -Then, we suppress all token ids that correspond to a token containing the letter `e` during generation. - -This is the logits processor used when we pass the config as below. - - - -```elixir - deftransform suppressed_tokens_processor(logits, _context, opts \\ []) do - opts = Keyword.validate!(opts, [:suppressed_token_ids]) - - indices = opts[:suppressed_token_ids] |> Nx.tensor() |> Nx.new_axis(-1) - values = Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), {Nx.size(indices)}) - Nx.indexed_put(logits, indices, values) - end -``` - -```elixir -generation_config = - Bumblebee.configure(generation_config, - suppressed_token_ids: token_ids_with_e - ) - -serving = - Bumblebee.Text.generation(model_info, tokenizer, generation_config, - compile: [batch_size: 1, sequence_length: 1024], - stream: false, - defn_options: [compiler: compiler] - ) - -%{results: [%{text: out}]} = Nx.Serving.run(serving, prompt) - -Kino.Text.new(out) -``` - -```elixir -String.contains?(out, "e") or String.contains?(out, "E") -``` - -```elixir -%{"input_ids" => out_token_ids} = Bumblebee.apply_tokenizer(tokenizer, out) - -out_token_ids = Nx.to_flat_list(out_token_ids) - -ids = Enum.filter(out_token_ids, &(&1 in token_ids_with_e)) - -Enum.map(ids, &Tokenizer.id_to_token(tokenizer, &1)) -``` From 1da730e83e4511528ae97ab9da70da97e457789e Mon Sep 17 00:00:00 2001 From: chris Date: Tue, 18 Nov 2025 07:56:00 +0100 Subject: [PATCH 36/36] logits_processor.ex aktualisieren MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jonatan Kłosko --- lib/bumblebee/logits_processor.ex | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/bumblebee/logits_processor.ex b/lib/bumblebee/logits_processor.ex index 40c74daf..21df999d 100644 --- a/lib/bumblebee/logits_processor.ex +++ b/lib/bumblebee/logits_processor.ex @@ -16,10 +16,10 @@ defmodule Bumblebee.LogitsProcessor do @type state :: Nx.Container.t() @type process_context :: %{ - sequence: Nx.Tensor.t(), - length: Nx.Tensor.t(), - input_length: Nx.Tensor.t() - } + sequence: Nx.Tensor.t(), + length: Nx.Tensor.t(), + input_length: Nx.Tensor.t() + } @type init_context :: %{}