Skip to content

Conversation

@AntonOresten
Copy link

@AntonOresten AntonOresten commented Nov 28, 2025

This PR defines methods for making cuDNN work with BFloat16s.BFloat16.

In the following example, I show how the new methods fixes the BFloat16 backward pass of Flux.logitcrossentropy:

Before

Note: Core.BFloat16 === BFloat16s.BFloat16, but I didn't explicitly import in this REPL session.

julia> x, y = CUDA.randn(Core.BFloat16, 32), CUDA.randn(Core.BFloat16, 32); Flux.gradient(x) do x
           Flux.logitcrossentropy(x, y)
       end
ERROR: MethodError: no method matching cudnnDataType(::Type{Core.BFloat16})
The function `cudnnDataType` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  cudnnDataType(::Type{Float16})
   @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:7
  cudnnDataType(::Type{Float32})
   @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:8
  cudnnDataType(::Type{Float64})
   @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:9
  ...

Stacktrace:
  [1] cuDNN.cudnnTensorDescriptor(array::CuArray{Core.BFloat16, 4, CUDA.DeviceMemory}; format::cuDNN.cudnnTensorFormat_t, dims::Vector{Int32})
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/tensor.jl:9
  [2] cudnnSoftmaxForward!(y::CuArray{…}, x::CuArray{…}; o::@Kwargs{})
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/softmax.jl:17
  [3] logsoftmax!(y::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}, x::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}; dims::Int64)
    @ NNlibCUDACUDNNExt ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:90
  [4] logsoftmax!
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:87 [inlined]
  [5] #logsoftmax#41
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:20 [inlined]
  [6] logsoftmax
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:19 [inlined]
  [7] #rrule#109
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:129 [inlined]
  [8] rrule
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:128 [inlined]
  [9] rrule
    @ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:144 [inlined]
 [10] chain_rrule_kw
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:246 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:-1 [inlined]
 [12] _pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81 [inlined]
 [13] #logitcrossentropy#20
    @ ~/.julia/packages/Flux/uRn8o/src/losses/functions.jl:272 [inlined]
 [14] _pullback(::Zygote.Context{…}, ::Flux.Losses.var"##logitcrossentropy#20", ::Int64, ::typeof(mean), ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [15] _pullback(::Zygote.Context{…}, ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
 [16] #8
    @ ./REPL[14]:2 [inlined]
 [17] _pullback(ctx::Zygote.Context{false}, f::var"#8#9", args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [18] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
 [19] pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
 [20] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153
 [21] #gradient#1
    @ ~/.julia/packages/Flux/uRn8o/src/gradient.jl:44 [inlined]
 [22] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Flux ~/.julia/packages/Flux/uRn8o/src/gradient.jl:31
 [23] top-level scope
    @ REPL[14]:1
Some type information was truncated. Use `show(err)` to see complete types.
After defining cudnnDataType(::Type{BFloat16})
julia> x, y = CUDA.randn(Core.BFloat16, 32), CUDA.randn(Core.BFloat16, 32); Flux.gradient(x) do x
           Flux.logitcrossentropy(x, y)
       end
ERROR: Unknown tensor type Core.BFloat16
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:44
  [2] scalingParameter(T::Type, val::Int64)
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/util.jl:34
  [3] cudnnSoftmaxForwardWithDefaults(x::CuArray{…}; y::CuArray{…}, algo::cuDNN.cudnnSoftmaxAlgorithm_t, mode::cuDNN.cudnnSoftmaxMode_t, alpha::Int64, beta::Int64, format::cuDNN.cudnnTensorFormat_t, xDesc::cuDNN.cudnnTensorDescriptor, yDesc::cuDNN.cudnnTensorDescriptor)
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/softmax.jl:34
  [4] cudnnSoftmaxForward!(y::CuArray{…}, x::CuArray{…}; o::@Kwargs{})
    @ cuDNN ~/.julia/packages/cuDNN/vKsqU/src/softmax.jl:17
  [5] logsoftmax!(y::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}, x::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}; dims::Int64)
    @ NNlibCUDACUDNNExt ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:90
  [6] logsoftmax!
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:87 [inlined]
  [7] #logsoftmax#41
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:20 [inlined]
  [8] logsoftmax
    @ ~/.julia/packages/NNlib/1TYHL/ext/NNlibCUDACUDNNExt/softmax.jl:19 [inlined]
  [9] #rrule#109
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:129 [inlined]
 [10] rrule
    @ ~/.julia/packages/NNlib/1TYHL/src/softmax.jl:128 [inlined]
 [11] rrule
    @ ~/.julia/packages/ChainRulesCore/Vsbj9/src/rules.jl:144 [inlined]
 [12] chain_rrule_kw
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:246 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:-1 [inlined]
 [14] _pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81 [inlined]
 [15] #logitcrossentropy#20
    @ ~/.julia/packages/Flux/uRn8o/src/losses/functions.jl:272 [inlined]
 [16] _pullback(::Zygote.Context{…}, ::Flux.Losses.var"##logitcrossentropy#20", ::Int64, ::typeof(mean), ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{…}, ::CuArray{…})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [17] _pullback(::Zygote.Context{false}, ::typeof(Flux.Losses.logitcrossentropy), ::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory}, ::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
 [18] #11
    @ ./REPL[19]:2 [inlined]
 [19] _pullback(ctx::Zygote.Context{false}, f::var"#11#12", args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [20] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
 [21] pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
 [22] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:153
 [23] #gradient#1
    @ ~/.julia/packages/Flux/uRn8o/src/gradient.jl:44 [inlined]
 [24] gradient(f::Function, args::CuArray{Core.BFloat16, 1, CUDA.DeviceMemory})
    @ Flux ~/.julia/packages/Flux/uRn8o/src/gradient.jl:31
 [25] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.
After defining scalingParameter(::Type{BFloat16}, val)
julia> x, y = CUDA.randn(Core.BFloat16, 32), CUDA.randn(Core.BFloat16, 32); Flux.gradient(x) do x
           Flux.logitcrossentropy(x, y)
       end
(Core.BFloat16[0.19335938, 0.32226562, -0.23828125, -0.85546875, 0.953125, 0.12207031, 1.15625, -0.64453125, -0.103515625, 0.61328125    0.4453125, -1.203125, 1.0234375, -1.46875, 0.19628906, -0.87890625, -1.3203125, 1.515625, 0.6484375, 0.44921875],)

I also define a cptr method for consistency, but it appears the function isn't used anywhere.

Tests are added for softmax, activations, and pooling. I initially also tested convolutions, normalization, RNNs, and MHA but they don't appear to support BFloat16.

Adding BFloat16s.jl as a dependency does not affect compilation since it's already a dependency of CUDA.jl.

Along with my proposed fix in FluxML/Optimisers.jl#215, this has allowed me to train LLMs in BFloat16 with Flux.jl in Julia v1.12. I am still working on an Optimisers.jl, but these together would be a significant unlock for my lab.

@github-actions
Copy link
Contributor

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/cudnn/src/util.jl b/lib/cudnn/src/util.jl
index 8923ff9b5..c7ec0c2bd 100644
--- a/lib/cudnn/src/util.jl
+++ b/lib/cudnn/src/util.jl
@@ -4,13 +4,13 @@ using BFloat16s: BFloat16
 cptr(x,a::DenseCuArray{Float64})=Float64[x]
 cptr(x,a::DenseCuArray{Float32})=Float32[x]
 cptr(x,a::DenseCuArray{Float16})=Float32[x]
-cptr(x,a::DenseCuArray{BFloat16})=Float32[x]
+cptr(x, a::DenseCuArray{BFloat16}) = Float32[x]
 
 # Conversion between Julia and cuDNN datatypes
 cudnnDataType(::Type{Float16})=CUDNN_DATA_HALF
 cudnnDataType(::Type{Float32})=CUDNN_DATA_FLOAT
 cudnnDataType(::Type{Float64})=CUDNN_DATA_DOUBLE
-cudnnDataType(::Type{BFloat16})=CUDNN_DATA_BFLOAT16
+cudnnDataType(::Type{BFloat16}) = CUDNN_DATA_BFLOAT16
 cudnnDataType(::Type{Int8}) = CUDNN_DATA_INT8
 cudnnDataType(::Type{UInt8}) = CUDNN_DATA_UINT8
 cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
@@ -21,7 +21,7 @@ cudnnDataType(::Type{Int32}) = CUDNN_DATA_INT32
 juliaDataType(a)=(a==CUDNN_DATA_HALF ? Float16 :
                   a==CUDNN_DATA_FLOAT ? Float32 :
                   a==CUDNN_DATA_DOUBLE ? Float64 :
-                  a==CUDNN_DATA_BFLOAT16 ? BFloat16 :
+        a == CUDNN_DATA_BFLOAT16 ? BFloat16 :
                   a==CUDNN_DATA_INT8 ? Int8 :
                   a==CUDNN_DATA_UINT8 ? UInt8 :
                   a==CUDNN_DATA_INT32 ? Int32 : error())
diff --git a/lib/cudnn/test/activation.jl b/lib/cudnn/test/activation.jl
index 7b7f2f01a..4164e0231 100644
--- a/lib/cudnn/test/activation.jl
+++ b/lib/cudnn/test/activation.jl
@@ -62,8 +62,8 @@ activationtest(alpha=2)
 activationtest(beta=2)
 
 # BFloat16 tests
-(ax,ay) = randn.(BFloat16, (10,10))
-(cx,cy) = CuArray.((ax,ay))
-activationtest(mode=CUDNN_ACTIVATION_SIGMOID)
-activationtest(mode=CUDNN_ACTIVATION_RELU)
-activationtest(mode=CUDNN_ACTIVATION_TANH)
+(ax, ay) = randn.(BFloat16, (10, 10))
+(cx, cy) = CuArray.((ax, ay))
+activationtest(mode = CUDNN_ACTIVATION_SIGMOID)
+activationtest(mode = CUDNN_ACTIVATION_RELU)
+activationtest(mode = CUDNN_ACTIVATION_TANH)
diff --git a/lib/cudnn/test/softmax.jl b/lib/cudnn/test/softmax.jl
index 68967bc1d..ab446813c 100644
--- a/lib/cudnn/test/softmax.jl
+++ b/lib/cudnn/test/softmax.jl
@@ -46,7 +46,7 @@ softmaxtest(algo=CUDNN_SOFTMAX_ACCURATE)
 softmaxtest(algo=CUDNN_SOFTMAX_LOG)
 
 # BFloat16 tests
-ax,ay = randn(BFloat16,10,10),randn(BFloat16,10,10)
-cx,cy = CuArray.((ax,ay))
+ax, ay = randn(BFloat16, 10, 10), randn(BFloat16, 10, 10)
+cx, cy = CuArray.((ax, ay))
 softmaxtest()
-softmaxtest(algo=CUDNN_SOFTMAX_LOG)
+softmaxtest(algo = CUDNN_SOFTMAX_LOG)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant