Skip to content

Commit 3faac7b

Browse files
committed
add L2 regularization to gcn model
1 parent 8b3b8c8 commit 3faac7b

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

examples/gcn.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using GeometricFlux
2+
using GraphSignals
23
using Flux
34
using Flux: onehotbatch, onecold, logitcrossentropy, throttle
45
using Flux: @epochs
@@ -7,9 +8,6 @@ using Statistics
78
using SparseArrays
89
using Graphs.SimpleGraphs
910
using CUDA
10-
using Random
11-
12-
Random.seed!([0x6044b4da, 0xd873e4f9, 0x59d90c0a, 0xde01aa81])
1311

1412
@load "data/cora_features.jld2" features
1513
@load "data/cora_labels.jld2" labels
@@ -19,21 +17,25 @@ num_nodes = 2708
1917
num_features = 1433
2018
hidden = 16
2119
target_catg = 7
22-
epochs = 100
20+
epochs = 200
21+
λ = 5e-4
2322

2423
## Preprocessing data
2524
train_X = Matrix{Float32}(features) |> gpu # dim: num_features * num_nodes
2625
train_y = Matrix{Float32}(labels) |> gpu # dim: target_catg * num_nodes
27-
fg = FeaturedGraph(g) |> gpu
26+
fg = FeaturedGraph(g) # pass to gpu together in model layers
2827

2928
## Model
3029
model = Chain(GCNConv(fg, num_features=>hidden, relu),
3130
Dropout(0.5),
3231
GCNConv(fg, hidden=>target_catg),
33-
) |> gpu
32+
) |> gpu;
33+
# do not show model architecture, showing CuSparseMatrix will trigger errors
3434

3535
## Loss
36-
loss(x, y) = logitcrossentropy(model(x), y)
36+
l2norm(x) = sum(abs2, x)
37+
# cross entropy with first layer L2 regularization
38+
loss(x, y) = logitcrossentropy(model(x), y) + λ*sum(l2norm, Flux.params(model[1]))
3739
accuracy(x, y) = mean(onecold(softmax(cpu(model(x)))) .== onecold(cpu(y)))
3840

3941

0 commit comments

Comments
 (0)