
Commit 5d151ff

Add SmolLM3 (#422)
Co-authored-by: Jonatan Kłosko <jonatanklosko@gmail.com>
1 parent bc1b452 commit 5d151ff

File tree

7 files changed (+819, −2 lines)


lib/bumblebee.ex

Lines changed: 6 additions & 0 deletions
@@ -188,6 +188,11 @@ defmodule Bumblebee do
       "RobertaForTokenClassification" => {Bumblebee.Text.Roberta, :for_token_classification},
       "RobertaForCausalLM" => {Bumblebee.Text.Roberta, :for_causal_language_modeling},
       "RobertaModel" => {Bumblebee.Text.Roberta, :base},
+      "SmolLM3Model" => {Bumblebee.Text.SmolLM3, :base},
+      "SmolLM3ForCausalLM" => {Bumblebee.Text.SmolLM3, :for_causal_language_modeling},
+      "SmolLM3ForQuestionAnswering" => {Bumblebee.Text.SmolLM3, :for_question_answering},
+      "SmolLM3ForSequenceClassification" => {Bumblebee.Text.SmolLM3, :for_sequence_classification},
+      "SmolLM3ForTokenClassification" => {Bumblebee.Text.SmolLM3, :for_token_classification},
       "SwinModel" => {Bumblebee.Vision.Swin, :base},
       "SwinForImageClassification" => {Bumblebee.Vision.Swin, :for_image_classification},
       "T5Model" => {Bumblebee.Text.T5, :base},
@@ -254,6 +259,7 @@ defmodule Bumblebee do
       "phi" => :code_gen,
       "phi3" => :llama,
       "roberta" => :roberta,
+      "smollm3" => :smollm3,
       "t5" => :t5,
       "whisper" => :whisper,
       "xlm-roberta" => :xlm_roberta,

lib/bumblebee/layers/transformer.ex

Lines changed: 15 additions & 2 deletions
@@ -21,6 +21,10 @@ defmodule Bumblebee.Layers.Transformer do
       is configured, this option controls whether the bias from the
       first block is used for all other blocks. Defaults to `false`

+    * `:rotary_embedding` - configuration of rotary embedding. Can be:
+      - a keyword list (applied to all blocks)
+      - a function that takes the block index and returns the configuration
+
     * `:name` - the prefix for layer names

   For all other options (including required options) see `block/2`.
@@ -49,8 +53,7 @@
       :layer_norm,
       :block_type,
       :attention_window_size,
-      :scale_attention_weights,
-      :rotary_embedding
+      :scale_attention_weights
     ]

     opts =
@@ -60,6 +63,7 @@
       [
         :name,
         :num_blocks,
+        :rotary_embedding,
         attention_mask: Layers.none(),
         attention_head_mask: Layers.none(),
         attention_relative_bias: nil,
@@ -80,6 +84,7 @@
     cross_attention_mask = opts[:cross_attention_mask]
     cross_attention_head_mask = opts[:cross_attention_head_mask]
     cache = opts[:cache]
+    rotary_embedding = opts[:rotary_embedding]

     block_opts = Keyword.take(opts, block_opts_keys)

@@ -109,6 +114,13 @@
             opts[:attention_relative_bias] || Layers.none()
         end

+        block_rotary_embedding =
+          case rotary_embedding do
+            nil -> nil
+            fun when is_function(fun, 1) -> fun.(idx)
+            config when is_list(config) -> config
+          end
+
         {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
           block(
             state.hidden_state,
@@ -121,6 +133,7 @@
             cross_attention_head_mask: block_cross_attention_head_mask,
             block_cache: block_cache,
             offset: offset,
+            rotary_embedding: block_rotary_embedding,
             name: join(name, idx)
           ] ++ block_opts
         )
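
The per-block form is what SmolLM3 needs: it disables rotary embedding on a subset of layers (a NoPE-style layout), so the configuration must vary with the block index. A sketch of a function one might pass as `:rotary_embedding` — the every-fourth-block rule and the option keys are illustrative assumptions, and `position_ids` would come from the surrounding model code:

defmodule RotaryConfigExample do
  # Hypothetical helper: returns the rotary embedding configuration for a
  # given block index, or nil to disable rotary embedding in that block.
  def for_block(position_ids, block_idx) do
    # Illustrative rule: skip rotary embedding on every fourth block.
    if rem(block_idx + 1, 4) == 0 do
      nil
    else
      [position_ids: position_ids, max_positions: 4096]
    end
  end
end

# Passed to Bumblebee.Layers.Transformer.blocks/2 as:
#   rotary_embedding: &RotaryConfigExample.for_block(position_ids, &1)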

lib/bumblebee/text/pre_trained_tokenizer.ex

Lines changed: 6 additions & 0 deletions
@@ -211,6 +211,12 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
         mask: "<mask>"
       }
     },
+    smollm3: %{
+      special_tokens: %{
+        eos: "<|im_end|>",
+        pad: "<|im_end|>"
+      }
+    },
     t5: %{
       special_tokens: %{
         bos: "<s>",
