Skip to content

Commit 7845f39

Browse files
author
Donglai Wei
committed
fix build error and formatting
1 parent 7db1d4f commit 7845f39

File tree

23 files changed

+225
-162
lines changed

23 files changed

+225
-162
lines changed

connectomics/config/hydra_config.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,7 @@ class TestDataConfig:
11001100
cache_suffix: str = "_prediction.h5"
11011101
# Image transformation (applied to test images during inference)
11021102
image_transform: ImageTransformConfig = field(default_factory=ImageTransformConfig)
1103-
# Label transformation (optional, typically not used for test mode to preserve raw labels for evaluation)
1103+
# Label transformation (optional). Typically unused in test mode to preserve raw labels
11041104
label_transform: Optional[LabelTransformConfig] = None
11051105

11061106

@@ -1125,7 +1125,7 @@ class TuneDataConfig:
11251125
tune_resolution: Optional[List[int]] = None
11261126
# Image transformation (applied to tune images during inference)
11271127
image_transform: ImageTransformConfig = field(default_factory=ImageTransformConfig)
1128-
# Label transformation (optional, typically not used for tune mode to preserve raw labels for evaluation)
1128+
# Label transformation (optional). Typically unused in tune mode to preserve raw labels
11291129
label_transform: Optional[LabelTransformConfig] = None
11301130

11311131

@@ -1209,36 +1209,6 @@ class TuneConfig:
12091209
parameter_space: ParameterSpaceConfig = field(default_factory=ParameterSpaceConfig)
12101210

12111211

1212-
# Allow safe loading of checkpoints with PyTorch 2.6+ weights_only defaults
1213-
try:
1214-
import torch
1215-
1216-
if hasattr(torch, "serialization") and hasattr(torch.serialization, "add_safe_globals"):
1217-
torch.serialization.add_safe_globals(
1218-
[
1219-
ParameterConfig,
1220-
DecodingParameterSpace,
1221-
PostprocessingParameterSpace,
1222-
ParameterSpaceConfig,
1223-
# Core config dataclasses (for Lightning checkpoints)
1224-
Config,
1225-
SystemConfig,
1226-
SystemTrainingConfig,
1227-
SystemInferenceConfig,
1228-
ModelConfig,
1229-
DataConfig,
1230-
OptimizationConfig,
1231-
MonitorConfig,
1232-
InferenceConfig,
1233-
TestConfig,
1234-
TuneConfig,
1235-
]
1236-
)
1237-
except Exception:
1238-
# Best-effort registration; ignore if torch not available at import time
1239-
pass
1240-
1241-
12421212
@dataclass
12431213
class Config:
12441214
"""Main configuration for PyTorch Connectomics.
@@ -1289,6 +1259,36 @@ class Config:
12891259
tune: Optional[TuneConfig] = None
12901260

12911261

1262+
# Allow safe loading of checkpoints with PyTorch 2.6+ weights_only defaults
1263+
try:
1264+
import torch
1265+
1266+
if hasattr(torch, "serialization") and hasattr(torch.serialization, "add_safe_globals"):
1267+
torch.serialization.add_safe_globals(
1268+
[
1269+
ParameterConfig,
1270+
DecodingParameterSpace,
1271+
PostprocessingParameterSpace,
1272+
ParameterSpaceConfig,
1273+
# Core config dataclasses (for Lightning checkpoints)
1274+
Config,
1275+
SystemConfig,
1276+
SystemTrainingConfig,
1277+
SystemInferenceConfig,
1278+
ModelConfig,
1279+
DataConfig,
1280+
OptimizationConfig,
1281+
MonitorConfig,
1282+
InferenceConfig,
1283+
TestConfig,
1284+
TuneConfig,
1285+
]
1286+
)
1287+
except Exception:
1288+
# Best-effort registration; ignore if torch not available at import time
1289+
pass
1290+
1291+
12921292
# Utility functions for common configuration tasks
12931293
def configure_edge_mode(
12941294
cfg: Config, mode: str = "seg-all", thickness: int = 1, processing_mode: str = "3d"

connectomics/data/augment/build.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ def _build_eval_transforms_impl(cfg: Config, mode: str = "val", keys: list[str]
408408
transforms.append(NormalizeLabelsd(keys=["label"]))
409409

410410
# Label transformations (affinity, distance transform, etc.)
411-
# For test/tune modes: NEVER apply label transforms (keep raw instance labels for evaluation)
411+
# For test/tune modes: NEVER apply label transforms
412+
# (keep raw instance labels for evaluation)
412413
# For val mode: use training label_transform config
413414
label_cfg = None
414415
if mode == "val":

connectomics/data/augment/monai_transforms.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@ def _apply_misalignment_translation(
9999

100100
output = np.zeros(out_shape, img.dtype)
101101
if mode == "slip":
102-
output = img[:, y0 : y0 + out_shape[1], x0 : x0 + out_shape[2]]
103-
output[idx] = img[idx, y1 : y1 + out_shape[1], x1 : x1 + out_shape[2]]
102+
output = img[:, y0:y0 + out_shape[1], x0:x0 + out_shape[2]]
103+
output[idx] = img[idx, y1:y1 + out_shape[1], x1:x1 + out_shape[2]]
104104
else:
105-
output[:idx] = img[:idx, y0 : y0 + out_shape[1], x0 : x0 + out_shape[2]]
106-
output[idx:] = img[idx:, y1 : y1 + out_shape[1], x1 : x1 + out_shape[2]]
105+
output[:idx] = img[:idx, y0:y0 + out_shape[1], x0:x0 + out_shape[2]]
106+
output[idx:] = img[idx:, y1:y1 + out_shape[1], x1:x1 + out_shape[2]]
107107

108108
if is_tensor:
109109
output = torch.from_numpy(output).to(device)
@@ -299,7 +299,7 @@ def _apply_missing_parts(
299299
x_start = self.R.randint(0, img.shape[2] - hole_w + 1)
300300

301301
# Create hole (set to 0 or mean value)
302-
img[section_idx, y_start : y_start + hole_h, x_start : x_start + hole_w] = 0
302+
img[section_idx, y_start: y_start + hole_h, x_start: x_start + hole_w] = 0
303303

304304
return img
305305

@@ -452,24 +452,24 @@ def _apply_cut_noise(
452452
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
453453
region = img[
454454
:,
455-
z_start : z_start + z_len,
456-
y_start : y_start + y_len,
457-
x_start : x_start + x_len,
455+
z_start: z_start + z_len,
456+
y_start: y_start + y_len,
457+
x_start: x_start + x_len,
458458
]
459459
noisy_region = np.clip(region + noise, 0, 1)
460460
img[
461461
:,
462-
z_start : z_start + z_len,
463-
y_start : y_start + y_len,
464-
x_start : x_start + x_len,
462+
z_start: z_start + z_len,
463+
y_start: y_start + y_len,
464+
x_start: x_start + x_len,
465465
] = noisy_region
466466
else:
467467
# (C, H, W) - 2D with channels
468468
noise_shape = (img.shape[0], y_len, x_len)
469469
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
470-
region = img[:, y_start : y_start + y_len, x_start : x_start + x_len]
470+
region = img[:, y_start: y_start + y_len, x_start: x_start + x_len]
471471
noisy_region = np.clip(region + noise, 0, 1)
472-
img[:, y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
472+
img[:, y_start: y_start + y_len, x_start: x_start + x_len] = noisy_region
473473
elif img.ndim == 3:
474474
# 3D case: (Z, Y, X) or (C, H, W)
475475
# Heuristic: if first dim is small (<=4), assume it's channel (2D with channels)
@@ -478,29 +478,29 @@ def _apply_cut_noise(
478478
# (C, H, W) - 2D with channels
479479
noise_shape = (img.shape[0], y_len, x_len)
480480
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
481-
region = img[:, y_start : y_start + y_len, x_start : x_start + x_len]
481+
region = img[:, y_start: y_start + y_len, x_start: x_start + x_len]
482482
noisy_region = np.clip(region + noise, 0, 1)
483-
img[:, y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
483+
img[:, y_start: y_start + y_len, x_start: x_start + x_len] = noisy_region
484484
else:
485485
# (Z, Y, X) - 3D
486486
z_len = max(1, int(self.length_ratio * img.shape[0])) # Ensure at least 1
487487
z_start = self.R.randint(0, max(1, img.shape[0] - z_len + 1))
488488
noise_shape = (z_len, y_len, x_len)
489489
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
490490
region = img[
491-
z_start : z_start + z_len, y_start : y_start + y_len, x_start : x_start + x_len
491+
z_start: z_start + z_len, y_start: y_start + y_len, x_start: x_start + x_len
492492
]
493493
noisy_region = np.clip(region + noise, 0, 1)
494494
img[
495-
z_start : z_start + z_len, y_start : y_start + y_len, x_start : x_start + x_len
495+
z_start: z_start + z_len, y_start: y_start + y_len, x_start: x_start + x_len
496496
] = noisy_region
497497
else:
498498
# 2D case: (H, W)
499499
noise_shape = (y_len, x_len)
500500
noise = self.R.uniform(-self.noise_scale, self.noise_scale, noise_shape)
501-
region = img[y_start : y_start + y_len, x_start : x_start + x_len]
501+
region = img[y_start: y_start + y_len, x_start: x_start + x_len]
502502
noisy_region = np.clip(region + noise, 0, 1)
503-
img[y_start : y_start + y_len, x_start : x_start + x_len] = noisy_region
503+
img[y_start: y_start + y_len, x_start: x_start + x_len] = noisy_region
504504

505505
if is_tensor:
506506
img = torch.from_numpy(img).to(device)
@@ -886,7 +886,7 @@ def _find_best_paste(
886886
neuron_tensor.flip(0) if neuron_tensor.ndim == 3 else neuron_tensor.flip(1)
887887
)
888888

889-
label_paste = labels[best_idx : best_idx + 1]
889+
label_paste = labels[best_idx: best_idx + 1]
890890

891891
if best_angle != 0:
892892
label_paste = self._rotate_3d(label_paste, best_angle)

connectomics/data/io/io.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def read_hdf5(
3737
Args:
3838
filename: Path to the HDF5 file
3939
dataset: Name of the dataset to read. If None, reads the first dataset
40-
slice_obj: Optional slice for partial loading (e.g., np.s_[0:10, :, :])
40+
slice_obj: Optional slice for partial loading (e.g., np.s_[0:10,:,:])
4141
4242
Returns:
4343
Data from the HDF5 file as numpy array
@@ -171,7 +171,7 @@ def read_image_as_volume(filename: str, drop_channel: bool = False) -> np.ndarra
171171
Raises:
172172
ValueError: If file format is not supported
173173
"""
174-
image_suffix = filename[filename.rfind(".") + 1 :].lower()
174+
image_suffix = filename[filename.rfind(".") + 1:].lower()
175175
if image_suffix not in SUPPORTED_IMAGE_FORMATS:
176176
raise ValueError(
177177
f"Unsupported format: {image_suffix}. Supported formats: {SUPPORTED_IMAGE_FORMATS}"
@@ -281,7 +281,7 @@ def read_volume(
281281
if filename.endswith(".nii.gz"):
282282
image_suffix = "nii.gz"
283283
else:
284-
image_suffix = filename[filename.rfind(".") + 1 :].lower()
284+
image_suffix = filename[filename.rfind(".") + 1:].lower()
285285

286286
if image_suffix in ["h5", "hdf5"]:
287287
data = read_hdf5(filename, dataset)
@@ -420,7 +420,7 @@ def get_vol_shape(filename: str, dataset: Optional[str] = None) -> tuple:
420420
if filename.endswith(".nii.gz"):
421421
image_suffix = "nii.gz"
422422
else:
423-
image_suffix = filename[filename.rfind(".") + 1 :].lower()
423+
image_suffix = filename[filename.rfind(".") + 1:].lower()
424424

425425
if image_suffix in ["h5", "hdf5"]:
426426
# HDF5: Read shape from metadata (no data loading)

connectomics/data/io/tiles.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,22 @@ def reconstruct_volume_from_tiles(
160160
if is_image: # Image data
161161
result[
162162
z - z0,
163-
y_actual_start - y0 : y_actual_end - y0,
164-
x_actual_start - x0 : x_actual_end - x0,
163+
y_actual_start - y0: y_actual_end - y0,
164+
x_actual_start - x0: x_actual_end - x0,
165165
] = patch[
166-
y_actual_start - y_patch_start : y_actual_end - y_patch_start,
167-
x_actual_start - x_patch_start : x_actual_end - x_patch_start,
166+
y_actual_start - y_patch_start: y_actual_end - y_patch_start,
167+
x_actual_start - x_patch_start: x_actual_end - x_patch_start,
168168
0,
169169
]
170170
else: # Label data
171171
result[
172172
z - z0,
173-
y_actual_start - y0 : y_actual_end - y0,
174-
x_actual_start - x0 : x_actual_end - x0,
173+
y_actual_start - y0: y_actual_end - y0,
174+
x_actual_start - x0: x_actual_end - x0,
175175
] = rgb_to_seg(
176176
patch[
177-
y_actual_start - y_patch_start : y_actual_end - y_patch_start,
178-
x_actual_start - x_patch_start : x_actual_end - x_patch_start,
177+
y_actual_start - y_patch_start: y_actual_end - y_patch_start,
178+
x_actual_start - x_patch_start: x_actual_end - x_patch_start,
179179
]
180180
)
181181

connectomics/data/process/crop.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def crop_volume(data, sz, st=(0, 0, 0)):
1616
st = np.array(st).astype(np.int32)
1717

1818
if data.ndim == 3:
19-
return data[st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
19+
return data[st[0]:st[0] + sz[0], st[1]:st[1] + sz[1], st[2]:st[2] + sz[2]]
2020
else: # crop spatial dimensions
21-
return data[:, st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
21+
return data[:, st[0]:st[0] + sz[0], st[1]:st[1] + sz[1], st[2]:st[2] + sz[2]]
2222

2323

2424
def get_valid_pos_torch(mask, vol_sz, valid_ratio):
@@ -64,9 +64,9 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
6464
if len(vol_sz) == 3:
6565
mask_sum = (
6666
mask_sum[
67-
pad_sz_pre[0] : pad_sz_post[0],
68-
pad_sz_pre[1] : pad_sz_post[1],
69-
pad_sz_pre[2] : pad_sz_post[2],
67+
pad_sz_pre[0]:pad_sz_post[0],
68+
pad_sz_pre[1]:pad_sz_post[1],
69+
pad_sz_pre[2]:pad_sz_post[2],
7070
]
7171
>= valid_thres
7272
)
@@ -86,7 +86,7 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
8686
)
8787
else:
8888
mask_sum = (
89-
mask_sum[pad_sz_pre[0] : pad_sz_post[0], pad_sz_pre[1] : pad_sz_post[1]] >= valid_thres
89+
mask_sum[pad_sz_pre[0]:pad_sz_post[0], pad_sz_pre[1]:pad_sz_post[1]] >= valid_thres
9090
)
9191
if mask_sum.max() > 0:
9292
yy, xx = np.meshgrid(np.arange(mask_sum.shape[0]), np.arange(mask_sum.shape[1]))

connectomics/decoding/optuna_tuner.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ def _validate_data(self):
175175
)
176176
if mask.shape != gt.shape:
177177
raise ValueError(
178-
f"Mask shape {mask.shape} doesn't match ground truth shape {gt.shape} in volume {i}"
178+
"Mask shape {mask_shape} doesn't match ground truth shape "
179+
"{gt_shape} in volume {vol_idx}".format(
180+
mask_shape=mask.shape, gt_shape=gt.shape, vol_idx=i
181+
)
179182
)
180183

181184
def optimize(self) -> optuna.Study:
@@ -351,7 +354,10 @@ def _objective(self, trial: optuna.Trial) -> float:
351354
import traceback
352355

353356
print(
354-
f"\n❌ Trial {self.trial_count} failed during post-processing (volume {vol_idx}):"
357+
"\n❌ Trial {count} failed during post-processing "
358+
"(volume {vol_idx}):".format(
359+
count=self.trial_count, vol_idx=vol_idx
360+
)
355361
)
356362
print(f" Parameters: {postproc_params}")
357363
print(f" Error: {e}")
@@ -374,7 +380,10 @@ def _objective(self, trial: optuna.Trial) -> float:
374380
import traceback
375381

376382
print(
377-
f"\n❌ Trial {self.trial_count} failed during metric computation (volume {vol_idx}):"
383+
"\n❌ Trial {count} failed during metric computation "
384+
"(volume {vol_idx}):".format(
385+
count=self.trial_count, vol_idx=vol_idx
386+
)
378387
)
379388
print(f" Metric: {metric_name}")
380389
print(f" Segmentation shape: {segmentation.shape}, dtype: {segmentation.dtype}")
@@ -396,7 +405,13 @@ def _objective(self, trial: optuna.Trial) -> float:
396405
# Show per-volume and average
397406
vol_str = ", ".join([f"{m:.4f}" for m in volume_metrics])
398407
print(
399-
f"Trial {self.trial_count:3d}: {metric_name}=[{vol_str}] avg={metric_value:.6f} ({direction})"
408+
"Trial {count:3d}: {metric}=[{volumes}] avg={avg:.6f} ({direction})".format(
409+
count=self.trial_count,
410+
metric=metric_name,
411+
volumes=vol_str,
412+
avg=metric_value,
413+
direction=direction,
414+
)
400415
)
401416

402417
return metric_value
@@ -778,9 +793,8 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
778793

779794
print(f"✓ Loaded {len(predictions)} prediction volume(s)")
780795
for i, pred in enumerate(predictions):
781-
print(
782-
f" Volume {i}: shape {pred.shape}, dtype {pred.dtype}, range [{pred.min():.3f}, {pred.max():.3f}]"
783-
)
796+
pred_range = f"[{pred.min():.3f}, {pred.max():.3f}]"
797+
print(f" Volume {i}: shape {pred.shape}, dtype {pred.dtype}, range {pred_range}")
784798

785799
# Step 3: Load ground truth
786800
print("\n[3/4] Loading ground truth labels...")
@@ -816,8 +830,10 @@ def run_tuning(model, trainer, cfg, checkpoint_path=None):
816830
for i, gt in enumerate(ground_truth):
817831
unique_labels = np.unique(gt)
818832
n_nonzero_labels = len(unique_labels) - (1 if 0 in unique_labels else 0)
833+
gt_range = f"[{gt.min()}, {gt.max()}]"
819834
print(
820-
f" Volume {i}: shape {gt.shape}, dtype {gt.dtype}, range [{gt.min()}, {gt.max()}], unique labels: {n_nonzero_labels}"
835+
f" Volume {i}: shape {gt.shape}, dtype {gt.dtype}, range {gt_range}, "
836+
f"unique labels: {n_nonzero_labels}"
821837
)
822838

823839
# Validate ground truth for this volume

0 commit comments

Comments
 (0)