We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1678fd1 commit 9916e14 — Copy full SHA for 9916e14
src/weights.jl
@@ -380,9 +380,7 @@ Base.:(==)(x::AbstractWeights, y::AbstractWeights) = false
380
381
Compute the weighted sum of an array `v` with weights `w`, optionally over the dimension `dim`.
382
"""
383
-wsum(v::AbstractVector, w::AbstractVector) = dot(v, w)
384
-wsum(v::AbstractArray, w::AbstractVector) = dot(vec(v), w)
385
-wsum(v::AbstractArray, w::AbstractVector, dims::Colon) = wsum(v, w)
+wsum(v::AbstractArray, w::AbstractVector, dims::Colon=:) = transpose(w) * vec(v)
386
387
## wsum along dimension
388
#
test/weights.jl
@@ -239,6 +239,8 @@ a = reshape(1.0:27.0, 3, 3, 3)
239
@testset "Sum $f" for f in weight_funcs
240
@test sum([1.0, 2.0, 3.0], f([1.0, 0.5, 0.5])) ≈ 3.5
241
@test sum(1:3, f([1.0, 1.0, 0.5])) ≈ 4.5
242
+ @test sum([1 + 2im, 2 + 3im], f([1.0, 0.5])) ≈ 2 + 3.5im
243
+ @test sum([[1, 2], [3, 4]], f([2, 3])) == [11, 16]
244
245
for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0])
246
@test sum(a, f(wt), dims=1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims=1)
@@ -250,6 +252,7 @@ end
250
252
@testset "Mean $f" for f in weight_funcs
251
253
@test mean([1:3;], f([1.0, 1.0, 0.5])) ≈ 1.8
254
@test mean(1:3, f([1.0, 1.0, 0.5])) ≈ 1.8
255
+ @test mean([1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ 2 + 3im
256
257
258
@test mean(a, f(wt), dims=1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt)
0 commit comments