Skip to content

Commit b832b02

Browse files
bzgoogle and bzgoogle authored
[CI] fix unit test (#1086)
Signed-off-by: bzgoogle <beinuoz_google_com@t1v-n-fa0da4f0-w-0.us-central1-c.c.cloud-tpu-inference-test.internal>
Co-authored-by: bzgoogle <beinuoz_google_com@t1v-n-fa0da4f0-w-0.us-central1-c.c.cloud-tpu-inference-test.internal>
1 parent 3e3f039 commit b832b02

File tree

2 files changed: 15 additions & 4 deletions

2 files changed: 15 additions & 4 deletions

tests/layers/jax/attention/test_deepseek_v3_attention.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ def test_mla_forward_pass(self, kv_cache_str):
8989
request_distribution=jnp.array([0, 0, 1], dtype=jnp.int32),
9090
)
9191

92-
mla.rope.initialize_cache()
92+
mla.rope.initialize_cache(self.mesh)
9393

9494
# Run forward pass
9595
new_kv_cache, output = mla(x,

tests/layers/jax/test_rope.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
1+
import jax
12
from jax import numpy as jnp
23
from jax._src import test_util as jtu
4+
from jax.sharding import Mesh
35

46
from tpu_inference.layers.jax.rope import (DeepseekScalingRotaryEmbedding,
57
RotaryEmbedding)
@@ -42,20 +44,29 @@ def test_apply_rope(self):
4244
rope_theta = 10000
4345
original_max_position_embeddings = 1
4446
scaling_factor = 2
47+
devices = jax.devices()
48+
mesh = Mesh(devices, ('data', ))
49+
4550
rope = DeepseekScalingRotaryEmbedding(
4651
rotary_dim=head_dim,
4752
rope_theta=rope_theta,
4853
original_max_position_embeddings=original_max_position_embeddings,
4954
scaling_factor=scaling_factor,
5055
dtype=jnp.float32)
51-
rope.initialize_cache()
56+
rope.initialize_cache(mesh)
57+
expected_padded_dim = 128
5258
self.assertTrue(
5359
rope.sin_cos_cache.shape == (scaling_factor *
5460
original_max_position_embeddings,
55-
head_dim))
61+
expected_padded_dim))
62+
63+
valid_cache_slice = rope.sin_cos_cache[:, :head_dim]
64+
5665
expected_sin_cos = jnp.array([[1.0693147, 0], [0.5777532, 0.8997973]],
5766
dtype=jnp.float32)
58-
self.assertArraysAllClose(rope.sin_cos_cache, expected_sin_cos)
67+
68+
self.assertArraysAllClose(valid_cache_slice, expected_sin_cos)
69+
5970
num_tokens = 2
6071
num_heads = 1
6172
positions = jnp.arange(num_tokens)

0 commit comments

Comments (0)