Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor data ingesting in example #95

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions examples/1_formation_energy/formation_energy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,45 @@ using Flux: @epochs
using ChemistryFeaturization
using AtomicGraphNets

function build_graphs(
prop,
featurization;
data_dir = nothing,
num_pts = nothing,
verbose = false,
)
info = CSV.read(joinpath(data_dir, prop * ".csv"), DataFrame)
y = Array(Float32.(info[!, Symbol(prop)]))

# shuffle data and pick out subset
indices = shuffle(1:size(info, 1))
if !isnothing(num_pts)
indices = indices[1:num_pts]
end
info = info[indices, :]
output = y[indices]

# next, make and featurize graphs
if verbose
println("Building graphs and feature vectors from structures...")
end

inputs = FeaturizedAtoms[]
id = "task_id"
for r in eachrow(info)
cifpath = string(joinpath(data_dir, prop * "_cifs", r[Symbol(id)] * ".cif"))
gr = AtomGraph(cifpath)
if gr === missing
continue
end
input = featurize(gr, featurization)
push!(inputs, input)
end
inputs, output
end

function train_formation_energy(;
num_pts = 100,
num_pts = 31417, # 100,
num_epochs = 5,
data_dir = joinpath(@__DIR__, "data"),
verbose = true,
Expand All @@ -21,7 +58,6 @@ function train_formation_energy(;
num_train = Int32(round(train_frac * num_pts))
num_test = num_pts - num_train
prop = "formation_energy_per_atom"
id = "task_id" # field by which to label each input material

# set up the featurization
featurization = GraphNodeFeaturization([
Expand All @@ -41,27 +77,8 @@ function train_formation_energy(;
num_hidden_layers = 1 # how many fully-connected layers after convolution and pooling?
opt = ADAM(0.001) # optimizer

# dataset...first, read in outputs
info = CSV.read(string(data_dir, prop, ".csv"), DataFrame)
y = Array(Float32.(info[!, Symbol(prop)]))

# shuffle data and pick out subset
indices = shuffle(1:size(info, 1))[1:num_pts]
info = info[indices, :]
output = y[indices]

# next, make and featurize graphs
if verbose
println("Building graphs and feature vectors from structures...")
end
inputs = FeaturizedAtoms[]

for r in eachrow(info)
cifpath = string(data_dir, prop, "_cifs/", r[Symbol(id)], ".cif")
gr = AtomGraph(cifpath)
input = featurize(gr, featurization)
push!(inputs, input)
end
inputs, output = build_graphs(prop, featurization; data_dir, num_pts, verbose)

# pick out train/test sets
if verbose
Expand Down