Commit d317c8f

Fix testmode batchnorm back (#739)
1 parent: 0db7f45

2 files changed: +5 -5 lines

lib/cudnn/batchnorm.jl

Lines changed: 4 additions & 5 deletions
@@ -5,9 +5,7 @@ end
 
 BNCache() = BNCache(nothing, nothing)
 
-@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1)
-
-@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y))
+@inline _wsize(y) = (fill(1, ndims(y)-2)..., size(y)[end-1], 1)
 
 # NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
 # so reshape a 2D Tensor into 4D
@@ -110,7 +108,8 @@ function cudnnBNBackward!(dg::DenseCuArray{T}, g::DenseCuArray{T}, db::DenseCuAr
   else
     ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
     dx .= dy .* reshape(g, _wsize(x)) .* ivar
-    dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), dims = (1,2,4))
-    db .= squeeze(sum(dy, _reddims(dy)), dims = (1,2,4))
+    rdims = ((1:ndims(x)-2)..., ndims(x))
+    dg .= vec(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, dims=rdims))
+    db .= vec(sum(dy, dims=rdims))
   end
 end
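
The core of the fix is in this second hunk: in test mode, the gradients for the scale (dg) and bias (db) are reductions of dy over every dimension except the channel one, and the old squeeze/positional-dims sum syntax plus the deleted _reddims helper are gone from current Julia. Below is a minimal CPU sketch of that shape logic, using only Base Julia and illustrative sizes; none of the sizes come from the package itself.

# Minimal CPU sketch (plain Julia, no GPU required) of the reduction logic above.
# wsize mirrors the new _wsize: all ones except the channel dimension, so per-channel
# parameters broadcast against a 4D input; rdims reduces over every non-channel dim.
x  = rand(Float32, 7, 7, 3, 4)                       # W × H × C × N (illustrative sizes)
dy = rand(Float32, size(x)...)
running_mean = rand(Float32, 3)
running_var  = rand(Float32, 3) .+ 1f0
eps = 1f-5

wsize = (fill(1, ndims(x)-2)..., size(x)[end-1], 1)  # (1, 1, 3, 1)
rdims = ((1:ndims(x)-2)..., ndims(x))                # (1, 2, 4)

ivar = 1 ./ sqrt.(reshape(running_var, wsize) .+ eps)
dg = vec(sum(dy .* (x .- reshape(running_mean, wsize)) .* ivar, dims=rdims))
db = vec(sum(dy, dims=rdims))

@assert size(dg) == size(db) == (3,)                 # one gradient entry per channel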

test/cudnn/nnlib.jl

Lines changed: 1 addition & 0 deletions
@@ -134,5 +134,6 @@ end
     m = CUDA.rand(Float32, 2, 5)
     for training in (false, true)
         CUDNN.batchnorm(v, v, m, v, v, 1.0; training=training)
+        CUDNN.∇batchnorm(v, v, m, m, v, v, 1.0; training=training)
     end
 end
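
For reference, a self-contained sketch of what the added test line exercises. It rests on assumptions: CUDNN stands for CUDA.CUDNN as in the test suite of this era, v (defined earlier in the test file, outside this hunk) is taken to be a length-2 per-channel vector matching the channel dimension of the 2×5 input m, and the argument order (scale, bias, input[, dy], running mean, running var, momentum) is read off the calls in the diff.

using CUDA
const CUDNN = CUDA.CUDNN      # assumption: batchnorm/∇batchnorm live here in this era of CUDA.jl

v = CUDA.rand(Float32, 2)     # assumed per-channel scale/bias/running-stat shape
m = CUDA.rand(Float32, 2, 5)  # 2 channels × 5 samples; reshaped to 4D internally (see NOTE above)

for training in (false, true)
    CUDNN.batchnorm(v, v, m, v, v, 1.0; training=training)
    # Before this commit the training=false backward path still used squeeze and the
    # deleted _reddims helper and therefore failed; it now runs in test mode as well.
    CUDNN.∇batchnorm(v, v, m, m, v, v, 1.0; training=training)
end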
