Skip to content

Commit 311f77c

Browse files
committed
switch to EXLA as the evaluator is lacking vectorisation support:
``` ** (RuntimeError) unexpected vectorized axes in evaluator for operation :add: #Nx.Tensor< vectorized[batch: 1] s32[1] Nx.Defn.Expr tensor a s32[1] b = reshape a s32[1][1] ```
1 parent 41dd2ad commit 311f77c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/bumblebee/text/generation_test.exs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@ defmodule Bumblebee.Text.GenerationTest do
156156
#
157157
# Now, with the processor below, we expect the sequence of [79, 80, 81 ..]
158158

159-
%{token_ids: token_ids} = generate.(params, inputs)
159+
%{token_ids: token_ids} =
160+
Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA)
160161

161162
# first token_id should be 79 as we enforce token_id 79
162163
assert_equal(token_ids[[0, 0]], 79)
@@ -186,7 +187,8 @@ defmodule Bumblebee.Text.GenerationTest do
186187
]
187188
)
188189

189-
%{token_ids: token_ids} = generate.(params, inputs)
190+
%{token_ids: token_ids} =
191+
Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA)
190192

191193
# result without logit processor: 80, 80, 80
192194

0 commit comments

Comments
 (0)