|
| 1 | +import jax |
1 | 2 | from jax import numpy as jnp |
2 | 3 | from jax._src import test_util as jtu |
| 4 | +from jax.sharding import Mesh |
3 | 5 |
|
4 | 6 | from tpu_inference.layers.jax.rope import (DeepseekScalingRotaryEmbedding, |
5 | 7 | RotaryEmbedding) |
@@ -42,20 +44,29 @@ def test_apply_rope(self): |
42 | 44 | rope_theta = 10000 |
43 | 45 | original_max_position_embeddings = 1 |
44 | 46 | scaling_factor = 2 |
| 47 | + devices = jax.devices() |
| 48 | + mesh = Mesh(devices, ('data', )) |
| 49 | + |
45 | 50 | rope = DeepseekScalingRotaryEmbedding( |
46 | 51 | rotary_dim=head_dim, |
47 | 52 | rope_theta=rope_theta, |
48 | 53 | original_max_position_embeddings=original_max_position_embeddings, |
49 | 54 | scaling_factor=scaling_factor, |
50 | 55 | dtype=jnp.float32) |
51 | | - rope.initialize_cache() |
| 56 | + rope.initialize_cache(mesh) |
| 57 | + expected_padded_dim = 128 |
52 | 58 | self.assertTrue( |
53 | 59 | rope.sin_cos_cache.shape == (scaling_factor * |
54 | 60 | original_max_position_embeddings, |
55 | | - head_dim)) |
| 61 | + expected_padded_dim)) |
| 62 | + |
| 63 | + valid_cache_slice = rope.sin_cos_cache[:, :head_dim] |
| 64 | + |
56 | 65 | expected_sin_cos = jnp.array([[1.0693147, 0], [0.5777532, 0.8997973]], |
57 | 66 | dtype=jnp.float32) |
58 | | - self.assertArraysAllClose(rope.sin_cos_cache, expected_sin_cos) |
| 67 | + |
| 68 | + self.assertArraysAllClose(valid_cache_slice, expected_sin_cos) |
| 69 | + |
59 | 70 | num_tokens = 2 |
60 | 71 | num_heads = 1 |
61 | 72 | positions = jnp.arange(num_tokens) |
|
0 commit comments