@@ -21,6 +21,10 @@ defmodule Bumblebee.Layers.Transformer do
2121 is configured, this option controls whether the bias from the
2222 first block is used for all other blocks. Defaults to `false`
2323
24+ * `:rotary_embedding` - configuration of rotary embedding. Can be:
25+ - a keyword list (applied to all blocks)
26+ - a function that takes the block index and returns the configuration
27+
2428 * `:name` - the prefix for layer names
2529
2630 For all other options (including required options) see `block/2`.
@@ -49,8 +53,7 @@ defmodule Bumblebee.Layers.Transformer do
4953 :layer_norm ,
5054 :block_type ,
5155 :attention_window_size ,
52- :scale_attention_weights ,
53- :rotary_embedding
56+ :scale_attention_weights
5457 ]
5558
5659 opts =
@@ -60,6 +63,7 @@ defmodule Bumblebee.Layers.Transformer do
6063 [
6164 :name ,
6265 :num_blocks ,
66+ :rotary_embedding ,
6367 attention_mask: Layers . none ( ) ,
6468 attention_head_mask: Layers . none ( ) ,
6569 attention_relative_bias: nil ,
@@ -80,6 +84,7 @@ defmodule Bumblebee.Layers.Transformer do
8084 cross_attention_mask = opts [ :cross_attention_mask ]
8185 cross_attention_head_mask = opts [ :cross_attention_head_mask ]
8286 cache = opts [ :cache ]
87+ rotary_embedding = opts [ :rotary_embedding ]
8388
8489 block_opts = Keyword . take ( opts , block_opts_keys )
8590
@@ -109,6 +114,13 @@ defmodule Bumblebee.Layers.Transformer do
109114 opts [ :attention_relative_bias ] || Layers . none ( )
110115 end
111116
117+ block_rotary_embedding =
118+ case rotary_embedding do
119+ nil -> nil
120+ fun when is_function ( fun , 1 ) -> fun . ( idx )
121+ config when is_list ( config ) -> config
122+ end
123+
112124 { hidden_state , attention , cross_attention , block_cache , attention_relative_bias } =
113125 block (
114126 state . hidden_state ,
@@ -121,6 +133,7 @@ defmodule Bumblebee.Layers.Transformer do
121133 cross_attention_head_mask: block_cross_attention_head_mask ,
122134 block_cache: block_cache ,
123135 offset: offset ,
136+ rotary_embedding: block_rotary_embedding ,
124137 name: join ( name , idx )
125138 ] ++ block_opts
126139 )
0 commit comments