Skip to content

Commit 8365426

Browse files
xhr15joelpaulkochjonatanklosko
authored
Add state to logits processing (#425)
Co-authored-by: Joel Koch <joel@bitcrowd.net> Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>
1 parent 9688192 commit 8365426

File tree

5 files changed

+360
-34
lines changed

5 files changed

+360
-34
lines changed

lib/bumblebee.ex

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,36 @@ defmodule Bumblebee do
10891089
end
10901090
end
10911091

1092+
@doc """
1093+
Initializes state for a new logits processor.
1094+
1095+
Returns `state`, which is an opaque `Nx.Container`, and it is then
1096+
passed to and returned from `process/4`.
1097+
"""
1098+
@doc type: :logits_processor
1099+
@spec logits_processor_init(
1100+
Bumblebee.LogitsProcessor.t(),
1101+
context :: Bumblebee.LogitsProcessor.init_context()
1102+
) :: Bumblebee.LogitsProcessor.state()
1103+
def logits_processor_init(%module{} = logits_processor, context) do
1104+
module.init(logits_processor, context)
1105+
end
1106+
1107+
@doc """
1108+
Processes logits, applying specific rules. Receives context, state and
1109+
logits, and returns updated logits and state.
1110+
"""
1111+
@doc type: :logits_processor
1112+
@spec logits_processor_process(
1113+
Bumblebee.LogitsProcessor.t(),
1114+
Bumblebee.LogitsProcessor.state(),
1115+
logits :: Nx.Tensor.t(),
1116+
context :: Bumblebee.LogitsProcessor.process_context()
1117+
) :: {Bumblebee.LogitsProcessor.state(), logits :: Nx.Tensor.t()}
1118+
def logits_processor_process(%module{} = logits_processor, state, logits, context) do
1119+
module.process(logits_processor, state, logits, context)
1120+
end
1121+
10921122
@doc """
10931123
Initializes state for a new scheduler loop.
10941124

lib/bumblebee/logits_processor.ex

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
defmodule Bumblebee.LogitsProcessor do
2+
@moduledoc """
3+
An interface for configuring and using logits processors.
4+
5+
Logits processors are used during autoregressive generation to modify
6+
predicted scores at each generation step. This allows for applying
7+
certain rules to the model output to control which tokens are picked
8+
at each generation step, and which are not.
9+
10+
Every module implementing this behaviour is expected to also define
11+
a configuration struct.
12+
"""
13+
14+
@type t :: Bumblebee.Configurable.t()
15+
16+
@type state :: Nx.Container.t()
17+
18+
@type process_context :: %{
19+
sequence: Nx.Tensor.t(),
20+
length: Nx.Tensor.t(),
21+
input_length: Nx.Tensor.t()
22+
}
23+
24+
@type init_context :: %{}
25+
26+
@doc """
27+
Initializes state for a new logits processor.
28+
29+
Returns `state`, which is an opaque `Nx.Container`, and it is then
30+
passed to and returned from `process/2`.
31+
32+
Oftentimes logits processors are stateless, in which case this
33+
function can return an empty container, such as `{}`.
34+
"""
35+
@callback init(t(), init_context()) :: state()
36+
37+
@doc """
38+
Processes logits, applying specific rules.
39+
"""
40+
@callback process(
41+
t(),
42+
state(),
43+
logits :: Nx.Tensor.t(),
44+
context :: process_context()
45+
) :: {state :: map(), logits :: Nx.Tensor.t()}
46+
end

0 commit comments

Comments
 (0)