Commit eaa8472

Author: Donglai Wei
Commit message: fix formatting and test files
1 parent: dcf0bf5

26 files changed, +300 -422 lines

connectomics/config/auto_config.py

Lines changed: 24 additions & 27 deletions
@@ -377,26 +377,24 @@ def auto_plan_config(
 
     # Collect manual overrides (values explicitly set in config)
     manual_overrides = {}
+    training_cfg = getattr(config.system, "training", None) if hasattr(config, "system") else None
     if hasattr(config, "data"):
-        if hasattr(config.data, "batch_size") and config.data.batch_size is not None:
-            manual_overrides["batch_size"] = config.data.batch_size
-        if hasattr(config.data, "num_workers") and config.data.num_workers is not None:
-            manual_overrides["num_workers"] = config.data.num_workers
+        if training_cfg and getattr(training_cfg, "batch_size", None) is not None:
+            manual_overrides["batch_size"] = training_cfg.batch_size
+        if training_cfg and getattr(training_cfg, "num_workers", None) is not None:
+            manual_overrides["num_workers"] = training_cfg.num_workers
         if hasattr(config.data, "patch_size") and config.data.patch_size is not None:
             manual_overrides["patch_size"] = config.data.patch_size
 
-    if hasattr(config, "training"):
-        if hasattr(config.training, "precision") and config.training.precision is not None:
-            manual_overrides["precision"] = config.training.precision
-        if (
-            hasattr(config.training, "accumulate_grad_batches")
-            and config.training.accumulate_grad_batches is not None
-        ):
-            manual_overrides["accumulate_grad_batches"] = config.training.accumulate_grad_batches
+    if hasattr(config, "optimization"):
+        if getattr(config.optimization, "precision", None) is not None:
+            manual_overrides["precision"] = config.optimization.precision
+        if getattr(config.optimization, "accumulate_grad_batches", None) is not None:
+            manual_overrides["accumulate_grad_batches"] = config.optimization.accumulate_grad_batches
 
-    if hasattr(config, "optimizer"):
-        if hasattr(config.optimizer, "lr") and config.optimizer.lr is not None:
-            manual_overrides["lr"] = config.optimizer.lr
+        opt_cfg = getattr(config.optimization, "optimizer", None)
+        if opt_cfg and getattr(opt_cfg, "lr", None) is not None:
+            manual_overrides["lr"] = opt_cfg.lr
 
     # Create planner
     planner = AutoConfigPlanner(

@@ -408,9 +406,8 @@ def auto_plan_config(
 
     # Plan
     use_mixed_precision = not (
-        hasattr(config, "training")
-        and hasattr(config.training, "precision")
-        and config.training.precision == "32"
+        hasattr(config, "optimization")
+        and getattr(config.optimization, "precision", None) == "32"
     )
 
     result = planner.plan(

@@ -423,20 +420,20 @@ def auto_plan_config(
     # Update config with planned values (if not manually overridden)
     OmegaConf.set_struct(config, False)  # Allow adding new fields
 
-    if "batch_size" not in manual_overrides:
-        config.data.batch_size = result.batch_size
-    if "num_workers" not in manual_overrides:
-        config.data.num_workers = result.num_workers
+    if "batch_size" not in manual_overrides and training_cfg is not None:
+        training_cfg.batch_size = result.batch_size
+    if "num_workers" not in manual_overrides and training_cfg is not None:
+        training_cfg.num_workers = result.num_workers
     if "patch_size" not in manual_overrides:
         config.data.patch_size = result.patch_size
 
     if "precision" not in manual_overrides:
-        config.training.precision = result.precision
+        config.optimization.precision = result.precision
     if "accumulate_grad_batches" not in manual_overrides:
-        config.training.accumulate_grad_batches = result.accumulate_grad_batches
+        config.optimization.accumulate_grad_batches = result.accumulate_grad_batches
 
-    if "lr" not in manual_overrides:
-        config.optimizer.lr = result.lr
+    if "lr" not in manual_overrides and hasattr(config, "optimization"):
+        config.optimization.optimizer.lr = result.lr
 
     OmegaConf.set_struct(config, True)  # Re-enable struct mode
 

@@ -460,7 +457,7 @@ def auto_plan_config(
     cfg = auto_plan_config(cfg, print_results=True)
 
     print("\nFinal Config Values:")
-    print(f" batch_size: {cfg.data.batch_size}")
+    print(f" batch_size: {cfg.system.training.batch_size}")
     print(f" patch_size: {cfg.data.patch_size}")
     print(f" precision: {cfg.optimization.precision}")
     print(f" lr: {cfg.optimization.optimizer.lr}")

connectomics/config/hydra_config.py

Lines changed: 1 addition & 3 deletions
@@ -450,9 +450,7 @@ class DataConfig:
     label_transform: LabelTransformConfig = field(default_factory=LabelTransformConfig)
 
     # Augmentation configuration (nested under data in YAML)
-    augmentation: Optional["AugmentationConfig"] = (
-        None  # Set to None for simple enabled flag, or full config for detailed control
-    )
+    augmentation: "AugmentationConfig" = field(default_factory=lambda: AugmentationConfig())
 
 
 @dataclass
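The change above makes augmentation an always-populated sub-config instead of Optional[...] = None. A standalone sketch of the pattern, with illustrative class bodies rather than the repo's full definitions: a mutable dataclass default has to go through field(default_factory=...), which is why the lambda factory appears in the diff.

from dataclasses import dataclass, field

@dataclass
class AugmentationConfig:
    enabled: bool = True  # illustrative field

@dataclass
class DataConfig:
    # equivalent to field(default_factory=lambda: AugmentationConfig())
    augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)

cfg = DataConfig()
print(cfg.augmentation.enabled)  # True; callers no longer need a None check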

connectomics/config/hydra_utils.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def update_from_cli(cfg: Config, overrides: List[str]) -> Config:
     """
     Update config from command-line overrides.
 
-    Supports dot notation: ['data.batch_size=4', 'model.architecture=unet3d']
+    Supports dot notation: ['system.training.batch_size=4', 'model.architecture=unet3d']
 
     Args:
         cfg: Base Config object
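Only the docstring example changes here: batch_size now lives under system.training. A hedged sketch of what such dot-notation overrides resolve to, using the generic OmegaConf mechanism rather than update_from_cli's actual implementation:

from omegaconf import OmegaConf

base = OmegaConf.create(
    {"system": {"training": {"batch_size": 2}}, "model": {"architecture": "unet3d"}}
)
cli = OmegaConf.from_dotlist(["system.training.batch_size=4", "model.architecture=unet3d"])
merged = OmegaConf.merge(base, cli)
print(merged.system.training.batch_size)  # 4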

connectomics/data/augment/monai_transforms.py

Lines changed: 31 additions & 7 deletions
@@ -38,7 +38,11 @@ def __init__(
         self.rotate_ratio = rotate_ratio
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        if self.prob <= 0:
+            self._do_transform = False
+            return data
         d = dict(data)
+        self.randomize(None)
         if not self._do_transform:
             return d
 

@@ -53,6 +57,10 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
             d[key] = self._apply_misalignment_translation(d[key])
         return d
 
+    def randomize(self, _: Any = None) -> None:
+        """Randomly decide whether to apply the transform."""
+        self._do_transform = self.R.rand() < self.prob
+
     def _apply_misalignment_translation(
         self, img: Union[np.ndarray, torch.Tensor]
     ) -> Union[np.ndarray, torch.Tensor]:

@@ -184,7 +192,12 @@ def __init__(
         self.num_sections = num_sections
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        if self.prob <= 0:
+            self._do_transform = False
+            return data
+
         d = dict(data)
+        self.randomize(None)
         if not self._do_transform:
             return d
 

@@ -193,29 +206,40 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
             d[key] = self._apply_missing_section(d[key])
         return d
 
+    def randomize(self, _: Any = None) -> None:
+        """Randomly decide whether to apply the transform."""
+        self._do_transform = self.R.rand() < self.prob
+
     def _apply_missing_section(
         self, img: Union[np.ndarray, torch.Tensor]
     ) -> Union[np.ndarray, torch.Tensor]:
         """Remove random sections from volume."""
-        if img.ndim < 3 or img.shape[0] <= 3:
+        if img.ndim < 3:
             return img  # Skip 2D or very small volumes
 
         # Handle both numpy and torch tensors
         is_tensor = isinstance(img, torch.Tensor)
 
+        depth_axis = 0
+        if img.ndim >= 4 and img.shape[0] <= 3:
+            depth_axis = 1  # Channel-first; depth is axis 1
+
+        depth = img.shape[depth_axis]
+        if depth <= 3:
+            return img
+
         # Select sections to remove (avoid first and last)
-        num_to_remove = min(self.num_sections, img.shape[0] - 2)
+        num_to_remove = min(self.num_sections, depth - 2)
         indices_to_remove = self.R.choice(
-            np.arange(1, img.shape[0] - 1), size=num_to_remove, replace=False
+            np.arange(1, depth - 1), size=num_to_remove, replace=False
         )
 
         if is_tensor:
-            # Keep sections that are NOT in indices_to_remove
-            keep_mask = torch.ones(img.shape[0], dtype=torch.bool, device=img.device)
+            keep_mask = torch.ones(depth, dtype=torch.bool, device=img.device)
             keep_mask[indices_to_remove] = False
-            return img[keep_mask]
+            return torch.index_select(img, dim=depth_axis, index=keep_mask.nonzero(as_tuple=False).squeeze(-1))
         else:
-            return np.delete(img, indices_to_remove, axis=0)
+            return np.delete(img, indices_to_remove, axis=depth_axis)
 
 
 class RandMissingPartsd(RandomizableTransform, MapTransform):
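The pattern added to both transforms is the same: bail out early when prob <= 0, otherwise call randomize() so self.R (the transform's random state) decides whether this particular call applies the augmentation. A self-contained toy sketch of that wiring follows; the class name and its slice-dropping behaviour are invented for illustration, only the prob/randomize plumbing mirrors the diff.

from typing import Any, Dict, Hashable

import numpy as np
from monai.transforms import MapTransform, RandomizableTransform


class RandDropFirstSliced(RandomizableTransform, MapTransform):
    """Toy dict transform that drops the first z-slice with probability `prob`."""

    def __init__(self, keys, prob: float = 0.5):
        MapTransform.__init__(self, keys)
        RandomizableTransform.__init__(self, prob)

    def randomize(self, _: Any = None) -> None:
        # self.R is the numpy RandomState provided by MONAI's Randomizable base
        self._do_transform = self.R.rand() < self.prob

    def __call__(self, data: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
        if self.prob <= 0:
            self._do_transform = False
            return data
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        for key in self.keys:
            d[key] = d[key][1:]  # drop the first section along the depth axis
        return d


vol = {"image": np.zeros((8, 64, 64))}
out = RandDropFirstSliced(keys=["image"], prob=1.0)(vol)
print(out["image"].shape)  # (7, 64, 64)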

connectomics/data/io/io.py

Lines changed: 3 additions & 3 deletions
@@ -171,7 +171,7 @@ def read_image_as_volume(filename: str, drop_channel: bool = False) -> np.ndarra
     Raises:
         ValueError: If file format is not supported
     """
-    image_suffix = filename[filename.rfind(".") + 1:].lower()
+    image_suffix = filename[filename.rfind(".") + 1 :].lower()
     if image_suffix not in SUPPORTED_IMAGE_FORMATS:
         raise ValueError(
             f"Unsupported format: {image_suffix}. Supported formats: {SUPPORTED_IMAGE_FORMATS}"

@@ -281,7 +281,7 @@ def read_volume(
     if filename.endswith(".nii.gz"):
         image_suffix = "nii.gz"
     else:
-        image_suffix = filename[filename.rfind(".") + 1:].lower()
+        image_suffix = filename[filename.rfind(".") + 1 :].lower()
 
     if image_suffix in ["h5", "hdf5"]:
         data = read_hdf5(filename, dataset)

@@ -420,7 +420,7 @@ def get_vol_shape(filename: str, dataset: Optional[str] = None) -> tuple:
     if filename.endswith(".nii.gz"):
         image_suffix = "nii.gz"
     else:
-        image_suffix = filename[filename.rfind(".") + 1:].lower()
+        image_suffix = filename[filename.rfind(".") + 1 :].lower()
 
     if image_suffix in ["h5", "hdf5"]:
         # HDF5: Read shape from metadata (no data loading)
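The io.py edits here (and the tiles.py and crop.py edits below) are whitespace-only: the slice colon now gets a space on both sides when a bound is an expression, consistent with the PEP 8 / Black spacing for complex slices. A tiny sketch confirming the two spellings index identically; the array and bounds are made up.

import numpy as np

x = np.arange(10)
start = 2
# identical slices, only the spacing around ':' differs
assert np.array_equal(x[start + 1:8], x[start + 1 : 8])
print(x[start + 1 : 8])  # [3 4 5 6 7]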

connectomics/data/io/tiles.py

Lines changed: 8 additions & 8 deletions
@@ -160,22 +160,22 @@ def reconstruct_volume_from_tiles(
        if is_image:  # Image data
            result[
                z - z0,
-                y_actual_start - y0: y_actual_end - y0,
-                x_actual_start - x0: x_actual_end - x0,
+                y_actual_start - y0 : y_actual_end - y0,
+                x_actual_start - x0 : x_actual_end - x0,
            ] = patch[
-                y_actual_start - y_patch_start: y_actual_end - y_patch_start,
-                x_actual_start - x_patch_start: x_actual_end - x_patch_start,
+                y_actual_start - y_patch_start : y_actual_end - y_patch_start,
+                x_actual_start - x_patch_start : x_actual_end - x_patch_start,
                0,
            ]
        else:  # Label data
            result[
                z - z0,
-                y_actual_start - y0: y_actual_end - y0,
-                x_actual_start - x0: x_actual_end - x0,
+                y_actual_start - y0 : y_actual_end - y0,
+                x_actual_start - x0 : x_actual_end - x0,
            ] = rgb_to_seg(
                patch[
-                    y_actual_start - y_patch_start: y_actual_end - y_patch_start,
-                    x_actual_start - x_patch_start: x_actual_end - x_patch_start,
+                    y_actual_start - y_patch_start : y_actual_end - y_patch_start,
+                    x_actual_start - x_patch_start : x_actual_end - x_patch_start,
                ]
            )
 
connectomics/data/process/crop.py

Lines changed: 6 additions & 6 deletions
@@ -16,9 +16,9 @@ def crop_volume(data, sz, st=(0, 0, 0)):
     st = np.array(st).astype(np.int32)
 
     if data.ndim == 3:
-        return data[st[0]: st[0] + sz[0], st[1]: st[1] + sz[1], st[2]: st[2] + sz[2]]
+        return data[st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
     else:  # crop spatial dimensions
-        return data[:, st[0]: st[0] + sz[0], st[1]: st[1] + sz[1], st[2]: st[2] + sz[2]]
+        return data[:, st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
 
 
 def get_valid_pos_torch(mask, vol_sz, valid_ratio):

@@ -64,9 +64,9 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
     if len(vol_sz) == 3:
         mask_sum = (
             mask_sum[
-                pad_sz_pre[0]: pad_sz_post[0],
-                pad_sz_pre[1]: pad_sz_post[1],
-                pad_sz_pre[2]: pad_sz_post[2],
+                pad_sz_pre[0] : pad_sz_post[0],
+                pad_sz_pre[1] : pad_sz_post[1],
+                pad_sz_pre[2] : pad_sz_post[2],
             ]
             >= valid_thres
         )

@@ -86,7 +86,7 @@ def get_valid_pos(mask, vol_sz, valid_ratio):
         )
     else:
         mask_sum = (
-            mask_sum[pad_sz_pre[0]: pad_sz_post[0], pad_sz_pre[1]: pad_sz_post[1]] >= valid_thres
+            mask_sum[pad_sz_pre[0] : pad_sz_post[0], pad_sz_pre[1] : pad_sz_post[1]] >= valid_thres
         )
         if mask_sum.max() > 0:
             yy, xx = np.meshgrid(np.arange(mask_sum.shape[0]), np.arange(mask_sum.shape[1]))
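crop_volume itself is unchanged apart from the slice spacing. For reference, a hedged usage sketch of the function exactly as shown in the hunk; the input shapes below are invented.

import numpy as np

def crop_volume(data, sz, st=(0, 0, 0)):
    # Crop a (z, y, x) block of size sz starting at offset st;
    # 4D inputs keep the leading channel axis intact.
    st = np.array(st).astype(np.int32)
    if data.ndim == 3:
        return data[st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]
    else:  # crop spatial dimensions
        return data[:, st[0] : st[0] + sz[0], st[1] : st[1] + sz[1], st[2] : st[2] + sz[2]]

vol = np.random.rand(2, 32, 64, 64)            # (channels, z, y, x)
patch = crop_volume(vol, sz=(16, 32, 32), st=(4, 8, 8))
print(patch.shape)                             # (2, 16, 32, 32)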
