@@ -127,15 +127,18 @@ def test_value(self, dtype, vsa):
127127 )
128128 else :
129129 hv = functional .circular (8 , 1000000 , vsa , generator = generator , dtype = dtype )
130- sims = functional .cosine_similarity (hv [0 ], hv )
131- sims_diff = sims [:- 1 ] - sims [1 :]
132- assert torch .all (
133- sims_diff .sign () == torch .tensor ([1 , 1 , 1 , 1 , - 1 , - 1 , - 1 ])
134- ), "second half must get more similar"
135-
136- assert torch .allclose (
137- sims_diff .abs (), torch .tensor (0.25 , dtype = sims_diff .dtype ), atol = 0.005
138- ), "similarity decreases linearly"
130+
131+ for i in range (8 - 1 ):
132+ sims = functional .cosine_similarity (hv [0 ], hv )
133+ sims_diff = sims [:- 1 ] - sims [1 :]
134+ assert torch .all (
135+ sims_diff .sign () == torch .tensor ([1 , 1 , 1 , 1 , - 1 , - 1 , - 1 ])
136+ ), f"element #{ i } : second half must get more similar"
137+
138+ assert torch .allclose (
139+ sims_diff .abs (), torch .tensor (0.25 , dtype = sims_diff .dtype ), atol = 0.005
140+ ), f"element #{ i } : similarity decreases linearly"
141+ hv = torch .roll (hv ,1 ,0 )
139142
140143 @pytest .mark .parametrize ("sparsity" , [0.0 , 0.1 , 0.756 , 1.0 ])
141144 @pytest .mark .parametrize ("dtype" , torch_dtypes )
0 commit comments