Add state to logits processing #425
@@ -106,4 +106,149 @@ defmodule Bumblebee.Text.GenerationTest do
    assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]]))
  end

  test "with stateful logits processor with batch size of 1" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

    {:ok, generation_config} =
      Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

    assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec

    input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]])
    attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
    seed = Nx.tensor([0])

    inputs = %{
      "input_ids" => input_ids,
      "attention_mask" => attention_mask,
      "seed" => seed
    }

    # We demonstrate the use of state with the stateful processor defined
    # below. On the first iteration it suppresses the given initial token ID,
    # then increments the ID to suppress on each following iteration. The ID
    # of the token to suppress is passed between iterations via the
    # logits_processor_state.
    #
    # Invoked with an initial ID of 79, it suppresses 79, 80, 81, ... in the
    # subsequent iterations, demonstrating the use of state in a logits
    # processor.

    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 2)

    generate =
      Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
        logits_processors: [
          &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2,
            initial_suppressed_token_id: [79]
          )
        ]
      )

    # Without the logits processor the result would be, as in the first
    # decoder test above: 80, 80, 80.
    #
    # With the processor we expect no change in the first step (the
    # suppressed token ID is 79), then a change to a different token ID (176)
    # as the suppressed token ID is incremented from 79 to 80, disallowing
    # the previously most likely token ID (80) from being selected.

    %{token_ids: token_ids} = generate.(params, inputs)

    # The first token ID is still 80, since we suppress token ID 79.
    assert_equal(token_ids[[0, 0]], 80)
    # In the next step we increment from 79 to 80 and suppress token ID 80;
    # the result is 176, the token with the next highest likelihood in the
    # logits.
    assert_equal(token_ids[[0, 1]], 176)
  end

  test "with stateful logits processor with batch size of 2" do
    assert {:ok, %{model: model, params: params, spec: spec}} =
             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

    {:ok, generation_config} =
      Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})

    assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec

    input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]])
    attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
    seed = Nx.tensor([0])

    inputs = %{
      "input_ids" => Nx.Batch.concatenate([input_ids, input_ids]),
      "attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]),
      "seed" => Nx.Batch.concatenate([seed, seed])
    }

    # This is the same example as above, but with a batch size of 2 and a
    # separate initial suppressed token ID for each batch entry.

    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)

    generate =
      Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
        logits_processors: [
          &Bumblebee.Text.GenerationTest.StatefulLogitsProcessing.stateful_processor(&1, &2,
            initial_suppressed_token_id: [78, 79]
          )
        ]
      )

    %{token_ids: token_ids} = generate.(params, inputs)

    # Without the logits processor the result would be: 80, 80, 80.

    # First entry in the batch:
    # the first token ID is still 80, since we suppress token ID 78
    assert_equal(token_ids[[0, 0]], 80)
    # the second token ID is still 80, since we suppress token ID 79
    assert_equal(token_ids[[0, 1]], 80)
    # in the next step we increment from 79 to 80 and suppress token ID 80
    assert_equal(token_ids[[0, 2]], 1016)

    # Second entry in the batch:
    # the first token ID is still 80, since we suppress token ID 79
    assert_equal(token_ids[[1, 0]], 80)
    # in the next step we increment from 79 to 80 and suppress token ID 80
    assert_equal(token_ids[[1, 1]], 176)
  end

  defmodule StatefulLogitsProcessing do
    import Nx.Defn

    deftransform stateful_processor(logits, context, opts \\ []) do
      # One initial suppressed token ID per batch entry, vectorized so that
      # the processor keeps separate state for each entry in the batch.
      initial_suppressed_token_ids = Enum.map(opts[:initial_suppressed_token_id], &List.wrap/1)
      initial_suppressed_token_id = Nx.tensor(initial_suppressed_token_ids) |> Nx.vectorize(:batch)

      # On the first iteration the state is empty, so we fall back to the
      # initial ID; afterwards we read the ID stored on the previous iteration.
      suppressed_id =
        context.logits_processor_state[:next_suppressed_token_id] || initial_suppressed_token_id

      logits = suppress_id(logits, suppressed_id)

      # Increment the ID and store it, so that the next iteration suppresses
      # the following token.
      next_suppressed_token_id = Nx.add(suppressed_id, 1)

      context =
        put_in(
          context,
          [:logits_processor_state, :next_suppressed_token_id],
          next_suppressed_token_id
        )

      {logits, context}
    end

    # Sets the logit at the given index to negative infinity, so the token
    # cannot be selected.
    defnp suppress_id(logits, id) do
      Nx.indexed_put(
        logits,
        id,
        Nx.Constants.neg_infinity(Nx.type(logits))
      )
    end
  end
end
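
The tests above also show the processor contract this PR exercises: a stateful logits processor receives the logits and a context, and returns both, threading its own state through context.logits_processor_state. As a minimal sketch of that contract (the StepCounter name, the :step key, and the no-op behavior are illustrative, not part of this PR), a processor that merely counts decoding iterations could look like:

defmodule StepCounter do
  import Nx.Defn

  # Minimal sketch, assuming the {logits, context} contract exercised by the
  # tests above: read the previous state from logits_processor_state, update
  # it, and return the (here unmodified) logits with the new context.
  deftransform count_steps(logits, context, _opts \\ []) do
    # The key is nil on the first iteration, hence the fallback to zero.
    step = context.logits_processor_state[:step] || Nx.tensor(0)

    context = put_in(context, [:logits_processor_state, :step], Nx.add(step, 1))

    {logits, context}
  end
end

Such a processor would be wired in via the same logits_processors option used in the tests, e.g. logits_processors: [&StepCounter.count_steps(&1, &2)].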