* Add device utility with MPS support and improve device handling
This commit adds comprehensive device detection and selection utilities
that support CUDA, MPS (Apple Silicon), and CPU backends with automatic
fallback logic.
Changes:
- Add decoder_pytorch/device.py with DeviceSelection dataclass and
get_optimal_device() function
- Update decoder_pytorch/__init__.py to export new device utilities
- Refactor train.py to use get_optimal_device() instead of hardcoded
device selection
- Use device-specific autocast dtype (bfloat16 for CUDA/MPS, float32
for CPU)
- Integrate TF32 configuration into get_optimal_device for CUDA
- Update fused optimizer check to only enable on CUDA (not MPS/CPU)
The get_optimal_device() function provides:
- Automatic device detection with configurable priority order
- Force device selection via parameter or FORCE_DEVICE env var
- Integrated TF32 configuration for CUDA devices
- Appropriate autocast dtype selection per device type
- Detailed device info logging
This ensures the codebase works seamlessly across CUDA, MPS, and CPU
devices with optimal settings for each platform.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
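The selection logic described above can be sketched as follows. The `DeviceSelection` and `get_optimal_device()` names come from the commit message; the `availability` parameter is an assumption that keeps the sketch runnable without PyTorch (the real utility would consult `torch.cuda.is_available()` and `torch.backends.mps.is_available()` instead):

```python
import os
from dataclasses import dataclass

@dataclass
class DeviceSelection:
    """Shape of the selection result; field names follow the commit,
    the concrete types are illustrative assumptions."""
    device: str       # e.g. "cuda:0", "mps", "cpu"
    device_type: str  # "cuda" | "mps" | "cpu"
    amp_dtype: str    # autocast dtype: bf16 for CUDA/MPS, fp32 for CPU

def get_optimal_device(availability, priority=("cuda", "mps", "cpu"), force=None):
    """Pick the first available backend, honoring force= / FORCE_DEVICE."""
    def dtype_for(backend):
        # Device-specific autocast dtype, as this commit describes.
        return "bfloat16" if backend in ("cuda", "mps") else "float32"

    forced = force or os.environ.get("FORCE_DEVICE")
    if forced:
        backend = forced.split(":")[0]  # "cuda:1" -> "cuda"
        return DeviceSelection(forced, backend, dtype_for(backend))
    for backend in priority:
        if availability.get(backend, False):
            suffix = ":0" if backend == "cuda" else ""
            return DeviceSelection(backend + suffix, backend, dtype_for(backend))
    return DeviceSelection("cpu", "cpu", "float32")  # last-resort fallback
```

The configurable `priority` tuple is what makes the fallback order adjustable, and routing both the `force=` argument and the `FORCE_DEVICE` env var through the same branch keeps the two override paths consistent.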
* Fix CPU autocast to use bfloat16 instead of float32
Changed CPU device selection to use torch.bfloat16 for autocast,
consistent with the repo's assumption of bfloat16-compatible hardware
(2025 AD standard).
This eliminates the warning:
"CPU Autocast only supports dtype of torch.bfloat16, torch.float16"
All devices (CUDA, MPS, CPU) now uniformly use bfloat16 for mixed
precision training.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
* Add use_autocast config option for mixed precision control
Added configurable autocast support to allow users to enable/disable
mixed precision training via config files without modifying code.
Changes:
- Add use_autocast config option to simple.yaml and test.yaml (default: true)
- Update train.py to conditionally use autocast based on config
- Use contextlib.nullcontext() when autocast is disabled
- Print mixed precision status on startup
Usage:
use_autocast: true # Enable bfloat16 mixed precision (default)
use_autocast: false # Disable, use full fp32 precision
Both configurations tested successfully with no warnings.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
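The conditional-autocast pattern above can be sketched like this. The `make_autocast` parameter is a stand-in for something like `functools.partial(torch.autocast, device_type=..., dtype=torch.bfloat16)`; that indirection is an assumption made so the sketch runs without torch, and `train_step_context` is a hypothetical name, not the repo's exact function:

```python
import contextlib

def train_step_context(config, make_autocast):
    """Choose the per-step precision context from the use_autocast flag."""
    if config.get("use_autocast", True):
        return make_autocast()          # bf16 mixed-precision region
    return contextlib.nullcontext()     # no-op context: plain fp32
```

Returning `contextlib.nullcontext()` lets the training loop keep a single `with ctx:` site for both modes instead of branching around the forward pass.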
* ✨ Ergonomic accelerator detection, safer selection, AMP hook, and docs
* Add `DeviceSelection.autocast_context()`; parse `cuda:N`, dedupe prefs, and warn on bad input (decoder_pytorch/device.py:26,53).
* Honor forced indices; guard out-of-range CUDA; loud CPU fallback for debug (decoder_pytorch/device.py:107).
* Use context helper for train/val; fix E731; AMP toggle driven by config (train.py:74).
* Document detection flow, `FORCE_DEVICE`, and autocast usage (README.md:27).
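The `cuda:N` parsing and out-of-range guard in these bullets might look roughly like the following. `parse_force` is a hypothetical helper (the commit only names the file and line), and `cuda_device_count` stands in for `torch.cuda.device_count()`:

```python
import warnings

def parse_force(value, cuda_device_count):
    """Validate a FORCE_DEVICE-style string: "cuda", "cuda:1", "mps", "cpu".

    Bad input warns loudly and falls back to CPU rather than raising,
    matching the "loud CPU fallback for debug" behaviour described above.
    """
    backend, _, index = value.partition(":")
    if backend not in ("cuda", "mps", "cpu"):
        warnings.warn(f"unknown device {value!r}; falling back to cpu")
        return "cpu"
    if backend == "cuda":
        idx = int(index) if index else 0  # honor a forced index, default 0
        if idx >= cuda_device_count:      # guard out-of-range CUDA index
            warnings.warn(f"cuda:{idx} out of range; falling back to cpu")
            return "cpu"
        return f"cuda:{idx}"
    return backend
```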
* clarify, format
* Simplify device handling with tuple API and improve training stability
- Replace device.py with lightweight tuple-based API and auto-fallbacks
- Centralize device checks in training; guard autocast and document grad quirks
- Graceful Flash Attention degradation when kernels unavailable
- Add nano.yaml config for quick CPU/MPS testing
- Update docs to reflect new device API and config
* Enforce autocast and fix nano config alignment
- Stop silently disabling autocast; always respect use_autocast flag
- Wrap autocast context manager on all devices (no silent fp32 fallback)
- Align nano.yaml to ~20M Llama with bf16 autocast enabled
- Clarify autocast behavior in README
* Set nano preset to L6·H384 (~9M) and update docs
- Update nano.yaml: depth 6, dim 384, torch.compile on
- Clarify model scale in README
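Pieced together from the commit messages, the nano preset would look roughly like this; the exact key names in `configs/nano.yaml` are assumptions, only the values are stated above:

```yaml
# Hypothetical sketch of configs/nano.yaml after this commit.
depth: 6             # L6: six transformer layers
dim: 384             # H384: hidden size (~9M params)
use_autocast: true   # bf16 mixed precision
compile: true        # torch.compile enabled
num_batches: 10000   # 10k steps for CPU / MPS shakedowns
```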
---------
Co-authored-by: Claude <noreply@anthropic.com>
README.md (+22 −4):

````diff
@@ -20,19 +20,37 @@ Train on the [included enwik8 dataset](data/README.md), character-level modeling
 # 100k batches on enwik8, 35M param Llama
 python train.py --config configs/simple.yaml
 
-# Quick test (tiny model, 10 batches)
+# Nano run for CPU / MPS shakedowns (10k steps, L6 · H384 · ~9M params)
+python train.py --config configs/nano.yaml
+
+# Quick smoke test (tiny model, 10 batches)
 python train.py --config configs/test.yaml
 ```
 
+## Device Selection & Precision
+
+- The training script calls `decoder_pytorch.get_optimal_device()` which prefers `cuda → mps → cpu`, returning `(device, device_type, amp_dtype)` and printing the accelerator picked.
+- Override detection with `FORCE_DEVICE=cuda`, `FORCE_DEVICE=cpu`, or even `FORCE_DEVICE=cuda:1` to pick a specific index (also available as the `force=` argument).
+- Mixed precision uses `torch.autocast` with `torch.bfloat16`; toggle via config if you want full fp32.
````

Footnote `[^2]` (line 134) was also updated:

````diff
-[^2]: If using PyTorch <2.9, you'll need to modify the TF32 configuration in `decoder_pytorch/utils.py` to use the legacy API (`torch.set_float32_matmul_precision("high")`) or skip TF32 setup entirely.
+[^2]: If using PyTorch <2.9, you may need to adjust the bfloat16/autocast behaviour or fall back to full fp32 depending on hardware support.
````