Merged

36 commits
fc0825a
[#SAMPLE-6] Add state to logits processing
joelpaulkoch Oct 17, 2025
01ab3af
stateful logits processors
joelpaulkoch Oct 16, 2025
5413662
adding another test
joelpaulkoch Oct 16, 2025
9d4ef39
fix test so compilation works
joelpaulkoch Oct 20, 2025
4ce01cc
demonstrate stateful logits processor through test assertions
joelpaulkoch Oct 20, 2025
2161b77
independent state for batch entries
joelpaulkoch Oct 20, 2025
fefc9fd
renamed initial_suppressed_token_index for clarity
xhr15 Oct 21, 2025
6e8612a
renamed next_suppressed_index -> :next_suppressed_token_index
xhr15 Oct 21, 2025
e43254a
logits_processor_states -> logits_processor_state in batch tests
xhr15 Oct 21, 2025
a2f0015
added a test with batch size 1 for clarity
xhr15 Oct 21, 2025
0cdc0ad
renaming suppressed_id -> suppressed_token_id
xhr15 Oct 21, 2025
cc6d6e3
more comments
xhr15 Oct 21, 2025
3816e7c
changed to demonstrate stack functionality
xhr15 Oct 23, 2025
fe58712
merged tests
xhr15 Oct 23, 2025
c97890a
removed test for processor only used in test
xhr15 Oct 23, 2025
fbf5ef3
introduces LogitsProcessor module
xhr15 Oct 24, 2025
dfa223c
clean up
joelpaulkoch Oct 27, 2025
9098bda
mix format
joelpaulkoch Oct 27, 2025
544d80f
vectorized sequences are called sequence in context
joelpaulkoch Oct 27, 2025
2ba5e0a
don't vectorize all the logits processor state
joelpaulkoch Oct 27, 2025
196c8f0
swap {logits, state} to {state, logits}
joelpaulkoch Nov 5, 2025
ee2a01e
rename logits_processor_state to logits_processor_states
joelpaulkoch Nov 5, 2025
3563ff0
states as tuples
joelpaulkoch Nov 5, 2025
6db771e
update test
joelpaulkoch Nov 5, 2025
c8442e0
single initial state for all batch entries
joelpaulkoch Nov 5, 2025
41dd2ad
vectorize sequence for init, derive vectorized state
joelpaulkoch Nov 5, 2025
311f77c
switch to EXLA as the evaluator is lacking vectorisation support:
xhr15 Nov 14, 2025
ec92264
Apply suggestion from @jonatanklosko
xhr15 Nov 14, 2025
201e103
removed comments
xhr15 Nov 14, 2025
70d7f65
slimmed down comments more
xhr15 Nov 14, 2025
ce92584
introduced types for init_context and process_context
xhr15 Nov 14, 2025
6d8f494
don't vectorize initial_enforced_token_id in test as it's the same ov…
xhr15 Nov 14, 2025
578ce11
bonus track: two more livebooks concerning logits processing. Not str…
xhr15 Nov 14, 2025
1f7798f
Update test/bumblebee/text/generation/logits_processing_test.exs
xhr15 Nov 17, 2025
e9d0c78
moving livebooks to separate PR
xhr15 Nov 17, 2025
1da730e
update logits_processor.ex
xhr15 Nov 18, 2025
45 changes: 30 additions & 15 deletions lib/bumblebee/text/generation.ex
@@ -387,8 +387,12 @@ defmodule Bumblebee.Text.Generation do
       end ++ logits_processors
 
     fn logits, context ->
-      for processor <- processors, processor, reduce: logits do
-        logits -> processor.(logits, context)
+      for processor <- processors, processor, reduce: {logits, context} do
+        {logits, context} ->
+          case processor.(logits, context) do
+            {logits, new_context} -> {logits, new_context}
+            logits -> {logits, context}
+          end
       end
     end
   end
@@ -551,7 +555,8 @@ defmodule Bumblebee.Text.Generation do
       length: length,
       finished_length: finished_length,
       # The ignored return value that we attach all hooks to
-      ignored: Nx.broadcast(0, {batch_size})
+      ignored: Nx.broadcast(0, {batch_size}),
+      logits_processor_states: %{}
     }
   end

@@ -574,7 +579,7 @@ defmodule Bumblebee.Text.Generation do
     outputs = predict_fun.(params, inputs)
 
     logits = outputs.logits[[.., -1]]
-    logits = batch_process_logits(logits_processor_fun, logits, state)
+    {logits, state} = batch_process_logits(logits_processor_fun, logits, state)
     token_id = Nx.argmax(logits, axis: -1)
 
     state = update_sequences(state, token_id, pad_token_id, eos_token_id)
@@ -632,14 +637,24 @@ defmodule Bumblebee.Text.Generation do
   end
 
   defnp batch_process_logits(logits_processor_fun, logits, state) do
-    logits
-    |> Nx.vectorize(:batch)
-    |> logits_processor_fun.(%{
-      sequence: Nx.vectorize(state.sequences, :batch),
-      length: state.length,
-      input_length: state.input_length
-    })
-    |> Nx.devectorize(keep_names: false)
+    logits = Nx.vectorize(logits, :batch)
+
+    {logits, new_context} =
+      logits_processor_fun.(logits, %{
+        sequence: Nx.vectorize(state.sequences, :batch),
+        length: state.length,
+        input_length: state.input_length,
+        logits_processor_state: Nx.vectorize(state.logits_processor_states, :batch)
+      })
+
+    logits = Nx.devectorize(logits, keep_names: false)
+
+    logits_processor_states =
+      Nx.devectorize(new_context.logits_processor_state, keep_names: false)
+
+    sequences = Nx.devectorize(new_context.sequence, keep_names: false)
+
+    {logits, %{state | sequences: sequences, logits_processor_states: logits_processor_states}}
   end
 
   # Contrastive search
@@ -684,7 +699,7 @@ defmodule Bumblebee.Text.Generation do
     joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state)
 
     logits = outputs.logits[[.., -1]]
-    logits = batch_process_logits(logits_processor_fun, logits, state)
+    {logits, state} = batch_process_logits(logits_processor_fun, logits, state)
     scores = Axon.Activations.softmax(logits, axis: -1)
     {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
 
@@ -727,7 +742,7 @@ defmodule Bumblebee.Text.Generation do
 
     logits = outputs.logits[[.., -1]]
     logits = Utils.Nx.chunked_take(logits, top_k, selected_idx)
-    logits = batch_process_logits(logits_processor_fun, logits, state)
+    {logits, state} = batch_process_logits(logits_processor_fun, logits, state)
 
     scores = Axon.Activations.softmax(logits, axis: -1)
     {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
@@ -888,7 +903,7 @@ defmodule Bumblebee.Text.Generation do
     outputs = predict_fun.(params, inputs)
 
     logits = outputs.logits[[.., -1]]
-    logits = batch_process_logits(logits_processor_fun, logits, state)
+    {logits, state} = batch_process_logits(logits_processor_fun, logits, state)
     scores = Axon.Activations.softmax(logits)
     token_id = batched_choice(key, scores)
 
50 changes: 49 additions & 1 deletion test/bumblebee/text/generation/logits_processing_test.exs
@@ -5,6 +5,53 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
 
   alias Bumblebee.Text.Generation.LogitsProcessing
 
+  describe "stateful logits processors" do
+    defmodule StatefulLogitsProcessing do
+      import Nx.Defn
+
+      deftransform stateful_processor(logits, context, opts) do
+        initial_suppressed_token_index = Nx.tensor([opts[:initial_suppressed_token_index]])
+
+        suppressed_index =
+          context.logits_processor_state[:next_suppressed_token_index] || initial_suppressed_token_index
+
+        values =
+          Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.size(suppressed_index))
+
+        logits = Nx.indexed_put(logits, suppressed_index, values)
+
+        next_suppressed_token_index = Nx.add(suppressed_index, Nx.tensor([1]))
+
+        context =
+          put_in(
+            context,
+            [:logits_processor_state, :next_suppressed_token_index],
+            next_suppressed_token_index
+          )
+
+        {logits, context}
+      end
+    end
+
+    test "can register and modify state" do
+      logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])
+
+      context = context([1, 0, 0, 0])
+
+      {logits, context} =
+        StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0)
+
+      assert_equal(logits, Nx.tensor([:neg_infinity, 2.0, 3.0, 4.0]))
+      assert_equal(context.logits_processor_state.next_suppressed_token_index, Nx.tensor([1]))
+
+      {logits, context} =
+        StatefulLogitsProcessing.stateful_processor(logits, context, initial_suppressed_token_index: 0)
+
+      assert_equal(logits, Nx.tensor([:neg_infinity, :neg_infinity, 3.0, 4.0]))
+      assert_equal(context.logits_processor_state.next_suppressed_token_index, Nx.tensor([2]))
+    end
+  end
+
   describe "suppressed_tokens_processor/3" do
     test "ignores the given tokens" do
       logits = Nx.tensor([1.0, 2.0, 3.0, 4.0])
@@ -382,7 +429,8 @@ defmodule Bumblebee.Text.Generation.LogitsProcessingTest do
     %{
       sequence: Nx.tensor(sequence),
       length: Enum.count(sequence, &(&1 != 0)),
-      input_length: 1
+      input_length: 1,
+      logits_processor_state: %{}
     }
   end
 end
145 changes: 145 additions & 0 deletions test/bumblebee/text/generation_test.exs
@@ -106,4 +106,149 @@ defmodule Bumblebee.Text.GenerationTest do
 
     assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]]))
   end
+
test "with stateful logits processor with batch size of 1" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec

input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]])
attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
seed = Nx.tensor([0])

inputs = %{
"input_ids" => input_ids,
"attention_mask" => attention_mask,
"seed" => seed
}

# We demonstrate the use of the state with the following example of a
# stateful processor (see below). On the first iteration, it suppresses the
# given initial ID, then increments the token ID to be suppressed on the
# following iterations. The ID of the token to be suppressed is passed on
# between iterations using the logits_processor_state.
#
# So invoked with the initial ID of 79, it suppresses 79, 80, 81, ... in
# the subsequent iterations, demonstrating the use of the state in a
# logits processor.

generation_config = Bumblebee.configure(generation_config, max_new_tokens: 2)

generate =
Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
logits_processors: [
&Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2,
initial_suppressed_token_id: [79]
)
]
)

# The result without the logits processor would be, as with the first
# decoder test above: 80, 80, 80.
#
# Now, with the processor below, we expect no change (suppressed token ID is
# 79), then a change to another random token ID (176) as the suppressed
# token ID is incremented from 79 to 80, disallowing the previous most
# likely token ID (80) from being selected.

%{token_ids: token_ids} = generate.(params, inputs)


# first token_id still 80 as we suppress token_id 79
assert_equal(token_ids[[0,0]], 80)
# in the next step we increment from 79 to 80 and suppress token_id 80, the
#result is 176 as that is the next likelihood in the logits.

assert_equal(token_ids[[0,1]], 176)
end

test "with stateful logits processor with batch size of 2" do
assert {:ok, %{model: model, params: params, spec: spec}} =
Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

{:ok, generation_config} =
Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec

input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]])
attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
seed = Nx.tensor([0])

inputs = %{
"input_ids" => Nx.Batch.concatenate([input_ids, input_ids]),
"attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]),
"seed" => Nx.Batch.concatenate([seed, seed])
}

# this is the same example as above, but with a batch size of 2.


generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)

generate =
Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
logits_processors: [
&Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2,
initial_suppressed_token_id: [78, 79]
)
]
)

%{token_ids: token_ids} = generate.(params, inputs)

# result without logit processor: 80, 80, 80

# first entry in batch
# first token_id still 80 as we suppress token_id 78
assert_equal(token_ids[[0, 0]], 80)
# second token_id still 80 as we suppress token_id 79
assert_equal(token_ids[[0, 1]], 80)
# in the next step we increment from 79 to 80 and suppress token_id 80
assert_equal(token_ids[[0, 2]], 1016)

# second entry in batch
# first token_id still 80 as we suppress token_id 79
assert_equal(token_ids[[1, 0]], 80)
# in the next step we increment from 79 to 80 and suppress token_id 80
assert_equal(token_ids[[1, 1]], 176)
end

+  defmodule StatefulLogitsProcessing do
+    import Nx.Defn
+
+    deftransform stateful_processor(logits, context, opts \\ []) do
+      initial_suppressed_token_ids = Enum.map(opts[:initial_suppressed_token_id], &List.wrap(&1))
+      initial_suppressed_token_id = Nx.tensor(initial_suppressed_token_ids) |> Nx.vectorize(:batch)
+
+      suppressed_id =
+        context.logits_processor_state[:next_suppressed_token_id] || initial_suppressed_token_id
+
+      logits = suppress_id(logits, suppressed_id)
+
+      next_suppressed_token_id = Nx.add(suppressed_id, 1)
+
+      context =
+        put_in(
+          context,
+          [:logits_processor_state, :next_suppressed_token_id],
jonatanklosko (Member) commented:

With the current API, the state is always initialized to %{} and then the first invocation of the processor adds a key, here %{next_suppressed_token_id: %Nx.Tensor{...}}.

This can be problematic in a defn while loop, which requires the accumulation state to always have the same shape. In other words, the initial state should already include :next_suppressed_token_id with the default tensor. It is possible that this didn't come up during your tests because, depending on the model/input, we do the first generation step outside of the while loop, and that first call would initialize the state. However, if we are going to support stateful processors, I would rather do it in a more robust way.

Given the above, a stateful logits processor would involve two steps (functions):

  1. Building an initial state.
  2. Performing logits processing, which receives logits and state, and returns updated logits and state.

This way we can call (1) when initializing the generation context, and for the actual processing we call (2).

The behaviour can be similar to Bumblebee.Scheduler. Something like this:

defmodule Bumblebee.LogitsProcessor do
  @moduledoc """
  An interface for configuring and using logits processors.

  Logits processors are used during autoregressive generation to modify
  predicted scores at each generation step. This allows for applying
  certain rules to the model output to control which tokens are picked
  at each generation step, and which are not.

  Every module implementing this behaviour is expected to also define
  a configuration struct.
  """

  @type t :: Bumblebee.Configurable.t()

  @type state :: Nx.Container.t()

  @doc """
  Initializes state for a new logits processor.

  Returns `state`, which is an opaque `Nx.Container`, and it is then
  passed to and returned from `process/2`.

  Oftentimes logits processors are stateless, in which case this
  function can return an empty container, such as `{}`.
  """
  @callback init(t(), context) :: state()
            when context: %{
                   prng_key: Nx.Tensor.t()
                 }

  @doc """
  Processes logits, applying specific rules.
  """
  @callback process(
              t(),
              state(),
              logits :: Nx.Tensor.t(),
              context :: context
            ) :: {state :: map(), logits :: Nx.Tensor.t()}
            when context: %{
                   sequence: Nx.Tensor.t(),
                   length: Nx.Tensor.t(),
                   input_length: Nx.Tensor.t()
                 }
end

Technically, the :logits_processors option is public API, but we can make it backward-compatible. For example, we can define %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: fun}, where the state is always empty and process just invokes the fun. I would even use that for the built-in processors, so that we don't need to define a bunch of new modules.
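As a rough illustration of that last idea (editorial sketch, not code from this PR): a stateless wrapper implementing the proposed behaviour could look something like the module below. The module name and the `Bumblebee.LogitsProcessor` callbacks follow the suggestion above and are assumptions, not existing Bumblebee API.

```elixir
# Hypothetical sketch based on the review suggestion above, not actual Bumblebee code.
defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do
  @moduledoc false

  # Wraps a plain `fun.(logits, context)` processor so it fits the proposed
  # two-step init/process behaviour, with an always-empty state.
  defstruct [:fun]

  @behaviour Bumblebee.LogitsProcessor

  @impl true
  def init(%__MODULE__{}, _context), do: {}

  @impl true
  def process(%__MODULE__{fun: fun}, state, logits, context) do
    # Stateless: the state stays an empty container, only the logits change.
    {state, fun.(logits, context)}
  end
end
```

Functions passed via the existing :logits_processors option could then be wrapped in this struct, keeping the option backward-compatible as described above.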

PR author (Contributor) replied:

@jonatanklosko Thank you very much for your comments! I think especially the two-step call makes sense. We'll move in that direction :)

PR author (Contributor) added:

@jonatanklosko, as an afterthought:

What is the use case for context here:

@callback init(t(), context) :: state()
            when context: %{
                   prng_key: Nx.Tensor.t()
                 }

Later in the loop, context holds:

context = %{
  sequences: sequences,
  input_length: length,
  length: length
}

I am wondering how those would influence the initialisation of the logits processors?

Or are you planning on using additional keys? E.g. from the state as returned by the init sequence:

%{
  sequences: sequences,
  input_length: length,
  length: length,
  finished_length: finished_length,
  ignored: Nx.broadcast(0, {batch_size})
}

If that was the case, we should probably rename the parameter to state or initial_state.

Wdyt?

jonatanklosko (Member) replied:

> What is the use case for context here:

I picked "context" in both functions as a generic name for state/metadata that may be relevant to the logits processor. You can see that in my snippet the context type is different for init and process. Technically all of the context fields could be separate arguments, but keeping it as a map makes the signature more manageable, and more importantly allows us to add more fields in the future without breaking compatibility.

Does that make sense?
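To make that concrete, here is a rough editorial sketch (assuming the `Bumblebee.LogitsProcessor` behaviour proposed above; the module and field names are illustrative) of how the suppress-and-increment processor from the tests could look with an explicit `init/2`, so the state already has its final shape before entering the `while` loop:

```elixir
# Hypothetical processor under the proposed behaviour, not code from this PR.
defmodule MyApp.IncrementingSuppressProcessor do
  defstruct [:initial_suppressed_token_id]

  @behaviour Bumblebee.LogitsProcessor

  @impl true
  def init(%__MODULE__{initial_suppressed_token_id: id}, _context) do
    # The tensor is created up front, so the state keeps the same shape
    # across all iterations of the defn while loop.
    %{next_suppressed_token_id: Nx.tensor([id])}
  end

  @impl true
  def process(%__MODULE__{}, state, logits, _context) do
    id = state.next_suppressed_token_id

    # Suppress the current token id by setting its logit to -inf.
    logits = Nx.indexed_put(logits, id, Nx.Constants.neg_infinity(Nx.type(logits)))

    # Move on to the next token id for the following step.
    {%{state | next_suppressed_token_id: Nx.add(id, 1)}, logits}
  end
end
```

Because the callbacks receive the whole context map rather than positional fields, more context keys could be added later without breaking implementations like this one, which is the compatibility point made above.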

+          next_suppressed_token_id
+        )
+
+      {logits, context}
+    end
+
+    defnp suppress_id(logits, id) do
+      Nx.indexed_put(
+        logits,
+        id,
+        Nx.Constants.neg_infinity(Nx.type(logits))
+      )
+    end
+  end
 end