
Commit 9c90a55

pszemraj and claude authored
support MPS, reorganize (#1)
* Add device utility with MPS support and improve device handling

  This commit adds comprehensive device detection and selection utilities that support CUDA, MPS (Apple Silicon), and CPU backends with automatic fallback logic.

  Changes:
  - Add decoder_pytorch/device.py with DeviceSelection dataclass and get_optimal_device() function
  - Update decoder_pytorch/__init__.py to export new device utilities
  - Refactor train.py to use get_optimal_device() instead of hardcoded device selection
  - Use device-specific autocast dtype (bfloat16 for CUDA/MPS, float32 for CPU)
  - Integrate TF32 configuration into get_optimal_device for CUDA
  - Update fused optimizer check to only enable on CUDA (not MPS/CPU)

  The get_optimal_device() function provides:
  - Automatic device detection with configurable priority order
  - Force device selection via parameter or FORCE_DEVICE env var
  - Integrated TF32 configuration for CUDA devices
  - Appropriate autocast dtype selection per device type
  - Detailed device info logging

  This ensures the codebase works seamlessly across CUDA, MPS, and CPU devices with optimal settings for each platform.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* Fix CPU autocast to use bfloat16 instead of float32

  Changed CPU device selection to use torch.bfloat16 for autocast, consistent with the repo's assumption of bfloat16-compatible hardware (2025 AD standard).

  This eliminates the warning: "CPU Autocast only supports dtype of torch.bfloat16, torch.float16"

  All devices (CUDA, MPS, CPU) now uniformly use bfloat16 for mixed precision training.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* Add use_autocast config option for mixed precision control

  Added configurable autocast support to allow users to enable/disable mixed precision training via config files without modifying code.

  Changes:
  - Add use_autocast config option to simple.yaml and test.yaml (default: true)
  - Update train.py to conditionally use autocast based on config
  - Use contextlib.nullcontext() when autocast is disabled
  - Print mixed precision status on startup

  Usage:
    use_autocast: true   # Enable bfloat16 mixed precision (default)
    use_autocast: false  # Disable, use full fp32 precision

  Both configurations tested successfully with no warnings.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude <noreply@anthropic.com>

* ✨ Ergonomic accelerator detection, safer selection, AMP hook, and docs

  * Add `DeviceSelection.autocast_context()`; parse `cuda:N`, dedupe prefs, and warn on bad input (decoder_pytorch/device.py:26,53).
  * Honor forced indices; guard out-of-range CUDA; loud CPU fallback for debug (decoder_pytorch/device.py:107).
  * Use context helper for train/val; fix E731; AMP toggle driven by config (train.py:74).
  * Document detection flow, `FORCE_DEVICE`, and autocast usage (README.md:27).

* clarify, format

* Simplify device handling with tuple API and improve training stability

  - Replace device.py with lightweight tuple-based API and auto-fallbacks
  - Centralize device checks in training; guard autocast and document grad quirks
  - Graceful Flash Attention degradation when kernels unavailable
  - Add nano.yaml config for quick CPU/MPS testing
  - Update docs to reflect new device API and config

* Enforce autocast and fix nano config alignment

  - Stop silently disabling autocast; always respect use_autocast flag
  - Wrap autocast context manager on all devices (no silent fp32 fallback)
  - Align nano.yaml to ~20M Llama with bf16 autocast enabled
  - Clarify autocast behavior in README

* Set nano preset to L6·H384 (~9M) and update docs

  - Update nano.yaml: depth 6, dim 384, torch.compile on
  - Clarify model scale in README

---------

Co-authored-by: Claude <noreply@anthropic.com>
1 parent c4fb075 commit 9c90a55

File tree

8 files changed

+264
-49
lines changed


README.md

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,19 +20,37 @@ Train on the [included enwik8 dataset](data/README.md), character-level modeling
2020
# 100k batches on enwik8, 35M param Llama
2121
python train.py --config configs/simple.yaml
2222

23-
# Quick test (tiny model, 10 batches)
23+
# Nano run for CPU / MPS shakedowns (10k steps, L6 · H384 · ~9M params)
24+
python train.py --config configs/nano.yaml
25+
26+
# Quick smoke test (tiny model, 10 batches)
2427
python train.py --config configs/test.yaml
2528
```
2629

30+
## Device Selection & Precision
31+
32+
- The training script calls `decoder_pytorch.get_optimal_device()` which prefers `cuda → mps → cpu`, returning `(device, device_type, amp_dtype)` and printing the accelerator picked.
33+
- Override detection with `FORCE_DEVICE=cuda`, `FORCE_DEVICE=cpu`, or even `FORCE_DEVICE=cuda:1` to pick a specific index (also available as the `force=` argument).
34+
- Mixed precision uses `torch.autocast` with `torch.bfloat16`; toggle via config if you want full fp32.
35+
36+
## Device Support
37+
38+
| Device | Status | Notes |
39+
| ------------- | ------ | --------------------------------------------------- |
40+
| NVIDIA GPU || Best performance, fused optimizer & flash attention |
41+
| Apple Silicon || Good performance, autocast can be flaky |
42+
| CPU || Slow but works; use `configs/nano.yaml` |
43+
2744
## Structure
2845

2946
```text
3047
decoder-pytorch-template/
3148
├── decoder_pytorch/ # Model implementation
3249
│ ├── llama.py # Llama architecture
33-
│ └── utils.py # Sampling utilities
50+
│ └── utils.py # Sampling & device helpers
3451
├── configs/ # Training configs
3552
│ ├── simple.yaml # Default config
53+
│ ├── nano.yaml # Quick CPU/MPS config
3654
│ └── test.yaml # Quick test config
3755
├── data/
3856
│ └── enwik8.gz # Character-level dataset
@@ -61,7 +79,7 @@ To add your own model architecture:
6179
4. **Update training script**: Modify `train.py` line 16 and 88:
6280

6381
```python
64-
from decoder_pytorch import YourModel, configure_tf32, model_summary
82+
from decoder_pytorch import YourModel, model_summary
6583
# ...
6684
model = YourModel(
6785
num_tokens=config.get("num_tokens", 256),
@@ -113,7 +131,7 @@ Dependencies:
113131
- einops, pyyaml, tqdm
114132
- [rotary-embedding-torch](https://github.com/lucidrains/rotary-embedding-torch)
115133
116-
[^2]: If using PyTorch <2.9, you'll need to modify the TF32 configuration in `decoder_pytorch/utils.py` to use the legacy API (`torch.set_float32_matmul_precision("high")`) or skip TF32 setup entirely.
134+
[^2]: If using PyTorch <2.9, you may need to adjust the bfloat16/autocast behaviour or fall back to full fp32 depending on hardware support.
117135
118136
## License
119137

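The selection rules described in the README's "Device Selection & Precision" section (prefer `cuda → mps → cpu`, with a `FORCE_DEVICE` override that may carry an index such as `cuda:1`) can be illustrated without PyTorch. The sketch below is a simplified stand-in, not the repository's `get_optimal_device()`: `pick_device` and its `available` argument are hypothetical, availability is passed in rather than probed from hardware, and the CUDA index is not range-checked.

```python
import os
from typing import Optional, Set

VALID_TYPES = ("cuda", "mps", "cpu")


def pick_device(available: Set[str], force: Optional[str] = None) -> str:
    """Return a device string, honoring a forced choice when it is usable.

    `available` stands in for the real hardware probes and is an
    assumption of this sketch; "cpu" is always treated as usable.
    """
    requested = (force or os.getenv("FORCE_DEVICE", "")).strip().lower()
    if requested:
        device_type = requested.split(":", 1)[0]  # "cuda:1" -> "cuda"
        usable = device_type == "cpu" or device_type in available
        if device_type in VALID_TYPES and usable:
            return requested
        print(f"Warning: cannot use '{requested}', falling back to auto-detect.")

    # Auto-detect: accelerators first, CPU as the last resort.
    for candidate in ("cuda", "mps"):
        if candidate in available:
            return candidate
    return "cpu"


print(pick_device({"cuda", "mps"}))           # auto-detect picks the first available accelerator
print(pick_device({"mps"}, force="cuda:1"))   # forced CUDA is unavailable, so it falls back
```

An unusable or unrecognized forced value produces a warning and then falls through to the normal priority order, rather than raising.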
configs/nano.yaml

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
1+
# Smaller than simple, larger than test (L6 · H384 · ~9M params)
2+
run_dir: runs/nano
3+
4+
# Model
5+
num_tokens: 256
6+
dim: 384
7+
depth: 6
8+
heads: 6
9+
dim_head: 64
10+
tied_embedding: true
11+
ffn_dim_multiplier: 1.5
12+
flash_attn: true
13+
compile: true # turn off if torch.compile is unavailable
14+
use_autocast: true
15+
16+
# Training schedule
17+
num_batches: 10000
18+
batch_size: 1
19+
grad_accum_every: 16
20+
learning_rate: 0.002
21+
weight_decay: 0.0003
22+
grad_clip_norm: 1.0
23+
24+
# Data
25+
data_path: data/enwik8.gz
26+
seq_len: 512
27+
28+
# training/validation/generation
29+
validate_every: 250
30+
val_batches: 20
31+
generate_every: 250
32+
save_every: 2000
33+
temperature: 1.0
34+
min_p: 0.1
35+
36+
seed: 7
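
As a rough sanity check on the schedule above, the token budget follows from `batch_size`, `grad_accum_every`, `seq_len`, and `num_batches`. The sketch assumes that each of the `num_batches` iterations is one optimizer step that accumulates `grad_accum_every` micro-batches; the training loop is not shown in this diff, so treat that as an assumption rather than a description of `train.py`.

```python
# Values copied from configs/nano.yaml.
config = {
    "num_batches": 10_000,
    "batch_size": 1,
    "grad_accum_every": 16,
    "seq_len": 512,
}

# Assumption: one "batch" is one optimizer step over grad_accum_every micro-batches.
sequences_per_step = config["batch_size"] * config["grad_accum_every"]
tokens_per_step = sequences_per_step * config["seq_len"]
total_tokens = tokens_per_step * config["num_batches"]

print(f"sequences per optimizer step: {sequences_per_step}")
print(f"tokens per optimizer step:    {tokens_per_step:,}")
print(f"tokens over the full run:     {total_tokens:,}")
```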

configs/simple.yaml

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ tied_embedding: true # share in/out embeddings
1212
ffn_dim_multiplier: 1.5 # hidden dim multiplier --> FFN size (here, 768)
1313
flash_attn: true # use flash attn (through torch api)
1414
compile: false # speed up training
15+
use_autocast: true # enable mixed precision (bfloat16)
1516

1617
# Training
1718
num_batches: 100000 # total steps
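
The commit message says `train.py` falls back to `contextlib.nullcontext()` when `use_autocast` is disabled. `train.py` is not part of the diff shown here, so the helper below (`make_autocast_context`, a hypothetical name) is only a sketch of that pattern. It imports PyTorch lazily so that the disabled path runs even where PyTorch is not installed.

```python
from contextlib import nullcontext
from typing import Any, ContextManager


def make_autocast_context(
    use_autocast: bool, device_type: str, amp_dtype: Any = None
) -> ContextManager:
    """Return an autocast context when enabled, otherwise a no-op context."""
    if not use_autocast:
        return nullcontext()  # full-precision path: nothing is cast

    import torch  # lazy import keeps the disabled path free of a torch dependency

    dtype = amp_dtype if amp_dtype is not None else torch.bfloat16
    return torch.autocast(device_type=device_type, dtype=dtype)


# A fresh context is created for each forward pass.
with make_autocast_context(use_autocast=False, device_type="cpu"):
    pass  # the model forward and loss computation would go here
```

Returning a context manager in both cases keeps the call site identical whether mixed precision is on or off.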

configs/test.yaml

Lines changed: 2 additions & 2 deletions
@@ -10,6 +10,7 @@ dim_head: 32
1010
tied_embedding: true
1111
flash_attn: true
1212
compile: false
13+
use_autocast: true
1314

1415
# Quick training
1516
num_batches: 10
@@ -28,5 +29,4 @@ val_batches: 50
2829
generate_every: 1000 # Don't generate during test
2930
save_every: 1000 # Don't save during test
3031

31-
# Random seed for reproducibility
32-
seed: 42
32+
seed: 7

decoder_pytorch/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@
22

33
from .llama import Llama
44
from .utils import (
5-
configure_tf32,
5+
get_optimal_device,
66
gumbel_noise,
77
gumbel_sample,
88
log,
@@ -22,6 +22,6 @@
2222
"top_k_filter",
2323
"top_p_filter",
2424
# Torch utilities
25-
"configure_tf32",
2625
"model_summary",
26+
"get_optimal_device",
2727
]

decoder_pytorch/llama.py

Lines changed: 19 additions & 0 deletions
@@ -140,6 +140,11 @@ def __init__(
140140
self.dim_head = dim_head
141141
self.causal = causal
142142
self.flash_attn = flash_attn
143+
if self.flash_attn and not self._flash_attn_available():
144+
print(
145+
"Warning: Flash attention requested but not available, using standard attention."
146+
)
147+
self.flash_attn = False
143148
self.max_seq_len = max_seq_len
144149

145150
inner_dim = heads * dim_head
@@ -158,6 +163,20 @@ def __init__(
158163
# Register causal mask buffer (will be created on first use)
159164
self.register_buffer("causal_mask", None, persistent=False)
160165

166+
def _flash_attn_available(self) -> bool:
167+
"""Return True if flash attention kernels are available."""
168+
if torch.cuda.is_available():
169+
return True
170+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
171+
return True
172+
dynamo = getattr(torch, "_dynamo", None)
173+
if dynamo is not None:
174+
try:
175+
return bool(dynamo.is_compiling())
176+
except Exception:
177+
return False
178+
return False
179+
161180
def forward(
162181
self,
163182
x: torch.Tensor,

decoder_pytorch/utils.py

Lines changed: 141 additions & 32 deletions
@@ -1,3 +1,4 @@
1+
import os
12
from dataclasses import dataclass
23
from typing import Dict, List, Optional, Set, Tuple
34

@@ -108,36 +109,85 @@ def top_p_filter(logits: Tensor, p: float = 0.9) -> Tensor:
108109
# --------------------------------------------------------------------------
109110

110111

111-
def configure_tf32() -> bool:
112-
"""Enable TF32 precision for GPUs with compute capability >= 8.0 (Ampere+).
112+
def _mps_available() -> bool:
113+
"""Return True if MPS is available."""
114+
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
113115

114-
Uses the PyTorch 2.9+ API for TF32 configuration.
115116

116-
:return: True if TF32 was enabled, False otherwise
117+
def get_optimal_device(
118+
force: Optional[str] = None,
119+
) -> Tuple[torch.device, str, torch.dtype]:
120+
"""Return best available accelerator (device, device_type, amp_dtype).
121+
122+
The function tries CUDA → MPS → CPU, unless the user forces a choice via
123+
the ``force`` argument or the ``FORCE_DEVICE`` environment variable. The
124+
return value is intentionally simple—a tuple that works well with tuple
125+
unpacking in training scripts.
117126
"""
118-
if not torch.cuda.is_available():
119-
print("No GPU detected, running on CPU.")
120-
return False
121-
122-
try:
123-
device = torch.cuda.current_device()
124-
capability = torch.cuda.get_device_capability(device)
125-
major, minor = capability
126-
gpu_name = torch.cuda.get_device_name(device)
127-
128-
if major >= 8:
129-
# PyTorch 2.9+ API for TF32 configuration
130-
torch.backends.cudnn.conv.fp32_precision = "tf32"
131-
torch.backends.cuda.matmul.fp32_precision = "tf32"
132-
print(f"{gpu_name} (compute {major}.{minor}) - TF32 enabled")
133-
return True
134-
else:
135-
print(f"{gpu_name} (compute {major}.{minor}) - TF32 not supported")
136-
return False
137127

138-
except Exception as e:
139-
print(f"Error: failed to configure GPU: {e}")
140-
return False
128+
def _normalize(device_str: str) -> str:
129+
return device_str.split(":", 1)[0]
130+
131+
requested = (force or os.getenv("FORCE_DEVICE", "")).strip().lower()
132+
valid_types = {"cuda", "mps", "cpu"}
133+
134+
if requested:
135+
requested_type = _normalize(requested)
136+
if requested_type not in valid_types:
137+
print(
138+
f"Warning: unsupported FORCE_DEVICE='{requested}'. "
139+
"Falling back to auto-detect."
140+
)
141+
requested = ""
142+
elif requested_type == "cuda" and not torch.cuda.is_available():
143+
print("Warning: CUDA requested but not available, falling back.")
144+
requested = ""
145+
elif requested_type == "mps" and not _mps_available():
146+
print("Warning: MPS requested but not available, falling back.")
147+
requested = ""
148+
149+
if requested:
150+
try:
151+
device = torch.device(requested)
152+
except (RuntimeError, ValueError) as err:
153+
print(f"Warning: could not create device '{requested}' ({err}).")
154+
requested = ""
155+
else:
156+
device_type = _normalize(requested)
157+
if device_type == "cuda":
158+
index = device.index or 0
159+
device_count = torch.cuda.device_count()
160+
if index >= device_count:
161+
print(
162+
f"Warning: CUDA index {index} unavailable "
163+
f"(found {device_count} device(s)). Falling back."
164+
)
165+
requested = ""
166+
else:
167+
name = torch.cuda.get_device_name(index)
168+
print(f"Using CUDA device {index}: {name}")
169+
return device, "cuda", torch.bfloat16
170+
elif device_type == "mps":
171+
print("Using Apple Silicon (MPS)")
172+
return device, "mps", torch.bfloat16
173+
else:
174+
print("Using CPU (forced)")
175+
return device, "cpu", torch.bfloat16
176+
177+
if torch.cuda.is_available():
178+
device = torch.device("cuda")
179+
name = torch.cuda.get_device_name(0)
180+
print(f"Using CUDA: {name}")
181+
return device, "cuda", torch.bfloat16
182+
183+
if _mps_available():
184+
device = torch.device("mps")
185+
print("Using Apple Silicon (MPS)")
186+
return device, "mps", torch.bfloat16
187+
188+
device = torch.device("cpu")
189+
print("Using CPU (no GPU acceleration available)")
190+
return device, "cpu", torch.bfloat16
141191

142192

143193
@dataclass
@@ -151,13 +201,19 @@ class _LayerSummary:
151201

152202

153203
def model_summary(
154-
model: nn.Module, max_depth: int = 4, show_param_shapes: bool = False
204+
model: nn.Module,
205+
max_depth: int = 4,
206+
show_param_shapes: bool = False,
207+
show_frozen_breakdown: bool = False,
155208
) -> None:
156209
"""Print hierarchical summary of model with parameter counts.
157210
158211
:param model: PyTorch model to summarize
159212
:param max_depth: Maximum depth of hierarchy to display
160213
:param show_param_shapes: Whether to show parameter shapes
214+
:param show_frozen_breakdown: If True, display separate trainable/frozen counts
215+
per module. Defaults to False for a simpler view that highlights whether
216+
a module is fully trainable, fully frozen, or mixed.
161217
"""
162218

163219
# ---------- formatting helpers ----------
@@ -248,15 +304,60 @@ def summarize_recursive(module: nn.Module, depth: int, prefix: str) -> Set[int]:
248304
max(len(_format_shape(s.param_shape)) for s in summary_list),
249305
)
250306

251-
params_col_width = 12
252-
trainable_col_width = 10
253-
col_spacing = " "
307+
params_col_width = max(
308+
len("Param #"),
309+
max(len(_format_number(s.inclusive_total_params)) for s in summary_list),
310+
)
254311

255312
header_parts = [f"{'Layer (type)':<{name_col_width}}"]
256313
if show_param_shapes:
257314
header_parts.append(f"{'Param Shape':>{shape_col_width}}")
315+
258316
header_parts.append(f"{'Param #':>{params_col_width}}")
259-
header_parts.append(f"{'Trainable':>{trainable_col_width}}")
317+
318+
if show_frozen_breakdown:
319+
trainable_col_width = max(
320+
len("Trainable #"),
321+
max(
322+
len(_format_number(s.inclusive_trainable_params)) for s in summary_list
323+
),
324+
)
325+
frozen_col_width = max(
326+
len("Frozen #"),
327+
max(
328+
len(
329+
_format_number(
330+
s.inclusive_total_params - s.inclusive_trainable_params
331+
)
332+
)
333+
for s in summary_list
334+
),
335+
)
336+
header_parts.append(f"{'Trainable #':>{trainable_col_width}}")
337+
header_parts.append(f"{'Frozen #':>{frozen_col_width}}")
338+
else:
339+
340+
def _grad_state(total: int, trainable: int) -> str:
341+
if trainable == 0:
342+
return "frozen"
343+
if trainable == total:
344+
return "trainable"
345+
return "mixed"
346+
347+
grad_states = [
348+
_grad_state(
349+
s.inclusive_total_params,
350+
s.inclusive_trainable_params,
351+
)
352+
for s in summary_list
353+
]
354+
grad_state_width = max(
355+
len("Grad State"), max(len(state) for state in grad_states)
356+
)
357+
header_parts.append(f"{'Grad State':>{grad_state_width}}")
358+
359+
col_spacing = " "
360+
260361
header = col_spacing.join(header_parts)
261362
sep = "=" * len(header)
262363

@@ -268,7 +369,15 @@ def summarize_recursive(module: nn.Module, depth: int, prefix: str) -> Set[int]:
268369
if show_param_shapes:
269370
parts.append(f"{_format_shape(e.param_shape):>{shape_col_width}}")
270371
parts.append(f"{_format_number(e.inclusive_total_params):>{params_col_width}}")
271-
parts.append(f"{str(e.inclusive_trainable_params > 0):>{trainable_col_width}}")
372+
if show_frozen_breakdown:
373+
parts.append(
374+
f"{_format_number(e.inclusive_trainable_params):>{trainable_col_width}}"
375+
)
376+
frozen = e.inclusive_total_params - e.inclusive_trainable_params
377+
parts.append(f"{_format_number(frozen):>{frozen_col_width}}")
378+
else:
379+
state = _grad_state(e.inclusive_total_params, e.inclusive_trainable_params)
380+
parts.append(f"{state:>{grad_state_width}}")
272381
print(col_spacing.join(parts))
273382
print(sep)
274383
print(f"Total params: {_format_number(total_params)}")
