diff --git a/examples/Project.toml b/examples/Project.toml
index 3d950f665..6145e4e5a 100644
--- a/examples/Project.toml
+++ b/examples/Project.toml
@@ -3,7 +3,9 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
 DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
+GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
diff --git a/examples/node_classification_cora_geometricflux.jl b/examples/node_classification_cora_geometricflux.jl
new file mode 100644
index 000000000..4188bb579
--- /dev/null
+++ b/examples/node_classification_cora_geometricflux.jl
@@ -0,0 +1,87 @@
+# An example of semi-supervised node classification
+
+using Flux
+using Flux: onecold, onehotbatch
+using Flux.Losses: logitcrossentropy
+using GeometricFlux, GraphSignals
+using MLDatasets: Cora
+using Statistics, Random
+using CUDA
+CUDA.allowscalar(false)
+
+function eval_loss_accuracy(X, y, ids, model)
+    ŷ = model(X)
+    l = logitcrossentropy(ŷ[:,ids], y[:,ids])
+    acc = mean(onecold(ŷ[:,ids]) .== onecold(y[:,ids]))
+    return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
+end
+
+# arguments for the `train` function
+Base.@kwdef mutable struct Args
+    η = 1f-3             # learning rate
+    epochs = 100         # number of epochs
+    seed = 17            # set seed > 0 for reproducibility
+    usecuda = true       # if true use cuda (if available)
+    nhidden = 128        # dimension of hidden features
+    infotime = 10        # report every `infotime` epochs
+end
+
+function train(; kws...)
+    args = Args(; kws...)
+
+    args.seed > 0 && Random.seed!(args.seed)
+
+    if args.usecuda && CUDA.functional()
+        device = gpu
+        args.seed > 0 && CUDA.seed!(args.seed)
+        @info "Training on GPU"
+    else
+        device = cpu
+        @info "Training on CPU"
+    end
+
+    # LOAD DATA
+    data = Cora.dataset()
+    g = FeaturedGraph(data.adjacency_list) |> device
+    X = data.node_features |> device
+    y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
+    train_ids = data.train_indices |> device
+    val_ids = data.val_indices |> device
+    test_ids = data.test_indices |> device
+    ytrain = y[:,train_ids]
+
+    nin, nhidden, nout = size(X,1), args.nhidden, data.num_classes
+
+    ## DEFINE MODEL
+    model = Chain(GCNConv(g, nin => nhidden, relu),
+                  Dropout(0.5),
+                  GCNConv(g, nhidden => nhidden, relu),
+                  Dense(nhidden, nout)) |> device
+
+    ps = Flux.params(model)
+    opt = ADAM(args.η)
+
+    @info g
+
+    ## LOGGING FUNCTION
+    function report(epoch)
+        train = eval_loss_accuracy(X, y, train_ids, model)
+        test = eval_loss_accuracy(X, y, test_ids, model)
+        println("Epoch: $epoch Train: $(train) Test: $(test)")
+    end
+
+    ## TRAINING
+    report(0)
+    for epoch in 1:args.epochs
+        gs = Flux.gradient(ps) do
+            ŷ = model(X)
+            logitcrossentropy(ŷ[:,train_ids], ytrain)
+        end
+
+        Flux.Optimise.update!(opt, ps, gs)
+
+        epoch % args.infotime == 0 && report(epoch)
+    end
+end
+
+train(usecuda=false)