@@ -132,6 +132,7 @@ python = "~=3.13.0"
132132numpy = " =1.22.0"
133133
134134# Backends that can run on CPU-only hosts
135+ # Note: JAX and PyTorch will install CPU variants.
135136[tool .pixi .feature .backends .dependencies ]
136137pytorch = " *"
137138dask = " *"
@@ -153,25 +154,34 @@ jax = "*"
153154[tool .pixi .feature .backends .target .win-64 .dependencies ]
154155# jax = "*" # unavailable
155156
156- # Backends that require a GPU host and a CUDA driver
157+ # Backends that require a GPU host and a CUDA driver.
158+ # Note that JAX and PyTorch automatically prefer CUDA variants
159+ # thanks to the `system-requirements` below, *if available*.
160+ # We request them explicitly below to ensure that we don't
161+ # quietly revert to CPU-only in the future, e.g. when CUDA 13
162+ # is released and CUDA 12 builds are dropped upstream.
157163[tool .pixi .feature .cuda-backends ]
158164system-requirements = { cuda = " 12" }
159165
160166[tool .pixi .feature .cuda-backends .target .linux-64 .dependencies ]
161167cupy = " *"
162168jaxlib = { version = " *" , build = " cuda12*" }
169+ pytorch = { version = " *" , build = " cuda12*" }
163170
164171[tool .pixi .feature .cuda-backends .target .osx-64 .dependencies ]
165172# cupy = "*" # unavailable
166173# jaxlib = { version = "*", build = "cuda12*" } # unavailable
174+ # pytorch = { version = "*", build = "cuda12*" } # unavailable
167175
168176[tool .pixi .feature .cuda-backends .target .osx-arm64 .dependencies ]
169177# cupy = "*" # unavailable
170178# jaxlib = { version = "*", build = "cuda12*" } # unavailable
179+ # pytorch = { version = "*", build = "cuda12*" } # unavailable
171180
172181[tool .pixi .feature .cuda-backends .target .win-64 .dependencies ]
173182cupy = " *"
174183# jaxlib = { version = "*", build = "cuda12*" } # unavailable
184+ pytorch = { version = " *" , build = " cuda12*" }
175185
176186[tool .pixi .environments ]
177187default = { features = [" py313" ], solve-group = " py313" }
0 commit comments