Skip to content

Commit 5e95eec

Browse files
committed
enh: update tests accordingly
1 parent e1b4498 commit 5e95eec

File tree

4 files changed

+177
-15
lines changed

4 files changed

+177
-15
lines changed

nitransforms/io/itk.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,9 @@ def from_image(cls, imgobj):
347347
warnings.warn("Incorrect intent identified.")
348348
hdr.set_intent("vector")
349349

350-
field = np.squeeze(np.asanyarray(imgobj.dataobj)).transpose(2, 1, 0, 3)
350+
field = np.squeeze(np.asanyarray(imgobj.dataobj))
351+
field[..., (0, 1)] *= 1.0
352+
field = field.transpose(2, 1, 0, 3)
351353
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
352354

353355
@classmethod
@@ -357,7 +359,9 @@ def to_image(cls, imgobj):
357359
hdr = imgobj.header.copy()
358360
hdr.set_intent("vector")
359361

360-
field = imgobj.get_fdata().transpose(2, 1, 0, 3)[..., None, :]
362+
field = imgobj.get_fdata()
363+
field = field.transpose(2, 1, 0, 3)[..., None, :]
364+
field[..., (0, 1)] *= 1.0
361365
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
362366

363367

nitransforms/tests/test_io.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,8 @@ def test_itk_disp_load_intent():
710710

711711
# Added tests for displacements fields orientations (ANTs/ITK)
712712
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
713-
def test_itk_displacements(tmp_path, get_testdata, image_orientation):
713+
@pytest.mark.parametrize("field_is_random", [False, True])
714+
def test_itk_displacements(tmp_path, get_testdata, image_orientation, field_is_random):
714715
"""Exercise I/O of ITK displacements fields."""
715716

716717
nii = get_testdata[image_orientation]
@@ -719,13 +720,17 @@ def test_itk_displacements(tmp_path, get_testdata, image_orientation):
719720
shape = nii.shape
720721
ref_affine = nii.affine.copy()
721722

722-
field = np.hstack(
723-
(
724-
np.linspace(-50, 50, num=np.prod(shape)),
725-
np.linspace(-80, 80, num=np.prod(shape)),
726-
np.zeros(np.prod(shape)),
727-
)
728-
).reshape(shape + (3,))
723+
field = (
724+
np.hstack(
725+
(
726+
np.linspace(-50, 50, num=np.prod(shape)),
727+
np.linspace(-80, 80, num=np.prod(shape)),
728+
np.zeros(np.prod(shape)),
729+
)
730+
).reshape(shape + (3,))
731+
if not field_is_random
732+
else np.random.normal(size=shape + (3,))
733+
)
729734

730735
nit_nii = itk.ITKDisplacementsField.to_image(
731736
nb.Nifti1Image(field, ref_affine, None)

nitransforms/tests/test_nonlinear.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33
"""Tests of nonlinear transforms."""
44

55
import os
6+
from subprocess import check_call
7+
import shutil
8+
9+
import SimpleITK as sitk
610
import pytest
711

812
import numpy as np
913
import nibabel as nb
14+
from nibabel.affines import from_matvec
1015
from nitransforms.resampling import apply
1116
from nitransforms.base import TransformError
1217
from nitransforms.nonlinear import (
@@ -38,6 +43,30 @@ def test_displacements_init():
3843
)
3944

4045

46+
@pytest.mark.parametrize("is_deltas", [True, False])
47+
def test_densefield_oob_resampling(is_deltas):
48+
"""Ensure mapping outside the field returns input coordinates."""
49+
ref = nb.Nifti1Image(np.zeros((2, 2, 2), dtype="uint8"), np.eye(4))
50+
51+
if is_deltas:
52+
field = nb.Nifti1Image(np.ones((2, 2, 2, 3), dtype="float32"), np.eye(4))
53+
else:
54+
grid = np.stack(
55+
np.meshgrid(*[np.arange(2) for _ in range(3)], indexing="ij"),
56+
axis=-1,
57+
).astype("float32")
58+
field = nb.Nifti1Image(grid + 1.0, np.eye(4))
59+
60+
xfm = DenseFieldTransform(field, is_deltas=is_deltas, reference=ref)
61+
62+
points = np.array([[-1.0, -1.0, -1.0], [0.5, 0.5, 0.5], [3.0, 3.0, 3.0]])
63+
mapped = xfm.map(points)
64+
65+
assert np.allclose(mapped[0], points[0])
66+
assert np.allclose(mapped[2], points[2])
67+
assert np.allclose(mapped[1], points[1] + 1)
68+
69+
4170
def test_bsplines_init():
4271
with pytest.raises(TransformError):
4372
BSplineFieldTransform(
@@ -177,3 +206,128 @@ def manual_map(x):
177206
pts = np.array([[1.2, 1.5, 2.0], [3.3, 1.7, 2.4]])
178207
expected = np.vstack([manual_map(p) for p in pts])
179208
assert np.allclose(bspline.map(pts), expected, atol=1e-6)
209+
210+
211+
def test_densefield_map_against_ants(testdata_path, tmp_path):
212+
"""Map points with DenseFieldTransform and compare to ANTs."""
213+
warpfile = (
214+
testdata_path
215+
/ "regressions"
216+
/ ("01_ants_t1_to_mniComposite_DisplacementFieldTransform.nii.gz")
217+
)
218+
if not warpfile.exists():
219+
pytest.skip("Composite transform test data not available")
220+
221+
points = np.array(
222+
[
223+
[0.0, 0.0, 0.0],
224+
[1.0, 2.0, 3.0],
225+
[10.0, -10.0, 5.0],
226+
[-5.0, 7.0, -2.0],
227+
[-12.0, 12.0, 0.0],
228+
]
229+
)
230+
csvin = tmp_path / "points.csv"
231+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
232+
233+
csvout = tmp_path / "out.csv"
234+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
235+
exe = cmd.split()[0]
236+
if not shutil.which(exe):
237+
pytest.skip(f"Command {exe} not found on host")
238+
check_call(cmd, shell=True)
239+
240+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
241+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
242+
243+
xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
244+
mapped = xfm.map(points)
245+
246+
assert np.allclose(mapped, ants_pts, atol=1e-6)
247+
248+
249+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
250+
@pytest.mark.parametrize("gridpoints", [True, False])
251+
def test_constant_field_vs_ants(tmp_path, get_testdata, image_orientation, gridpoints):
252+
"""Create a constant displacement field and compare mappings."""
253+
254+
nii = get_testdata[image_orientation]
255+
256+
# Create a reference centered at the origin with various axis orders/flips
257+
shape = nii.shape
258+
ref_affine = nii.affine.copy()
259+
260+
field = np.hstack((
261+
np.zeros(np.prod(shape)),
262+
np.linspace(-80, 80, num=np.prod(shape)),
263+
np.linspace(-50, 50, num=np.prod(shape)),
264+
)).reshape(shape + (3, ))
265+
fieldnii = nb.Nifti1Image(field, ref_affine, None)
266+
267+
warpfile = tmp_path / "itk_transform.nii.gz"
268+
ITKDisplacementsField.to_filename(fieldnii, warpfile)
269+
270+
# Ensure direct (xfm) and ITK roundtrip (itk_xfm) are equivalent
271+
xfm = DenseFieldTransform(fieldnii)
272+
itk_xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
273+
274+
assert xfm == itk_xfm
275+
np.testing.assert_allclose(xfm.reference.affine, itk_xfm.reference.affine)
276+
np.testing.assert_allclose(ref_affine, itk_xfm.reference.affine)
277+
np.testing.assert_allclose(xfm.reference.shape, itk_xfm.reference.shape)
278+
np.testing.assert_allclose(xfm._field, itk_xfm._field)
279+
280+
points = (
281+
xfm.reference.ndcoords.T if gridpoints
282+
else np.array(
283+
[
284+
[0.0, 0.0, 0.0],
285+
[1.0, 2.0, 3.0],
286+
[10.0, -10.0, 5.0],
287+
[-5.0, 7.0, -2.0],
288+
[12.0, 0.0, -11.0],
289+
]
290+
)
291+
)
292+
293+
mapped = xfm.map(points)
294+
nit_deltas = mapped - points
295+
296+
if gridpoints:
297+
np.testing.assert_array_equal(field, nit_deltas.reshape(*shape, -1))
298+
299+
csvin = tmp_path / "points.csv"
300+
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
301+
302+
csvout = tmp_path / "out.csv"
303+
cmd = f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout} -t {warpfile}"
304+
exe = cmd.split()[0]
305+
if not shutil.which(exe):
306+
pytest.skip(f"Command {exe} not found on host")
307+
check_call(cmd, shell=True)
308+
309+
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
310+
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
311+
312+
# if gridpoints:
313+
# ants_field = ants_pts.reshape(shape + (3, ))
314+
# diff = xfm._field[..., 0] - ants_field[..., 0]
315+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
316+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
317+
318+
# diff = xfm._field[..., 1] - ants_field[..., 1]
319+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
320+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
321+
322+
# diff = xfm._field[..., 2] - ants_field[..., 2]
323+
# mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
324+
# assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
325+
326+
ants_deltas = ants_pts - points
327+
np.testing.assert_array_equal(nit_deltas, ants_deltas)
328+
np.testing.assert_array_equal(mapped, ants_pts)
329+
330+
diff = mapped - ants_pts
331+
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
332+
333+
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"

nitransforms/tests/test_resampling.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def test_displacements_field1(
192192

193193
xfm = nitnl.load(xfm_fname, fmt=sw_tool)
194194

195+
import pdb; pdb.set_trace()
196+
195197
# Then apply the transform and cross-check with software
196198
cmd = APPLY_NONLINEAR_CMD[sw_tool](
197199
transform=os.path.abspath(xfm_fname),
@@ -247,11 +249,7 @@ def test_displacements_field1(
247249
assert np.sqrt((diff[5:-5, 5:-5, 5:-5] ** 2).mean()) < 1e-6
248250

249251

250-
@pytest.mark.xfail(
251-
reason="Disable while #266 is developed.",
252-
strict=False,
253-
)
254-
@pytest.mark.parametrize("sw_tool", ["itk", "afni"])
252+
@pytest.mark.parametrize("sw_tool", ["afni"])
255253
def test_displacements_field2(tmp_path, testdata_path, sw_tool):
256254
"""Check a translation-only field on one or more axes, different image orientations."""
257255
os.chdir(str(tmp_path))
@@ -283,6 +281,7 @@ def test_displacements_field2(tmp_path, testdata_path, sw_tool):
283281
nt_moved = apply(xfm, img_fname, order=0)
284282
nt_moved.to_filename("nt_resampled.nii.gz")
285283
sw_moved.set_data_dtype(nt_moved.get_data_dtype())
284+
286285
diff = np.asanyarray(
287286
sw_moved.dataobj, dtype=sw_moved.get_data_dtype()
288287
) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())

0 commit comments

Comments
 (0)