Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate AD functionality to current AtomGraphs #11

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.1.3"
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
AtomsIO = "1692102d-eeb4-4df9-807b-c9517f998d44"
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChemistryFeaturization = "6c925690-434a-421d-aea7-51398c5b007a"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
Expand All @@ -22,6 +23,8 @@ SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Xtals = "ede5f01d-793e-4c47-9885-c447d1f18d6d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AtomsBase = "0.2, 0.3"
Expand All @@ -38,7 +41,7 @@ NearestNeighbors = "0.4"
PythonCall = "0.9.10"
SimpleWeightedGraphs = "1.2"
StaticArrays = "1.5"
Unitful = "1.12"
Unitful = "1.12"
Xtals = "0.3, 0.4"
julia = "1.6, 1.7"

Expand Down
4 changes: 4 additions & 0 deletions src/AtomGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ using MolecularGraph
import AtomsBase: atomic_symbol

import ChemistryFeaturization: elements
using ChainRulesCore
using ZygoteRules

include("atomgraph.jl")
export AtomGraph, elements, visualize

include("graph_building.jl")
export inverse_square, exp_decay

include("adjoints.jl")

end
50 changes: 50 additions & 0 deletions src/adjoints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Zygote
using ZygoteRules: @adjoint

# Zygote adjoint for building a `Dict` from a generator.
# Forward pass: every item of `g.iter` is run through `Zygote.pullback(g.f, item)`, the
# resulting values are splatted into `Dict`, and the per-item pullbacks are kept.
# Backward pass: each kept pullback is applied to the incoming cotangent `Δ`, the
# `.second` field of its first result is stored under the corresponding key, and the
# collected pairs are handed back wrapped in a generator.
@adjoint function Dict(g::Base.Generator)
    results = [Zygote.pullback(g.f, item) for item in g.iter]
    entries, pullbacks = Zygote.unzip(results)
    function dict_pb(Δ)
        # NOTE(review): pullbacks are matched to the keys of Δ purely by iteration
        # position; this presumes `pairs(Δ)` iterates in the same order as `g.iter` —
        # confirm.
        grads = Dict(key => pb(Δ)[1].second for (pb, (key, _)) in zip(pullbacks, pairs(Δ)))
        ((entry for entry in grads),)
    end
    Dict(entries...), dict_pb
end

# Custom Zygote adjoint for `_cutoff!`.
# Forward pass: calls `_cutoff!` with the same arguments and returns its two outputs
# `(y, ld)` together with the pullback `cutoff_pb`.
@adjoint function _cutoff!(weight_mat, f, ijd,
nb_counts, longest_dists;
max_num_nbr = 12)
y, ld = _cutoff!(weight_mat, f, ijd,
nb_counts, longest_dists;
max_num_nbr = max_num_nbr)
# Pullback: receives the cotangent pair for `(y, ld)`. Only the first component `Δ`
# is used; the second component `nt` is never read.
function cutoff_pb((Δ,nt))
# Remember the incoming shape, then work on a flat, materialised copy of Δ.
s = size(Δ)
Δ = vec(collect(Δ))
# Walk the linear indices of Δ in lockstep with the entries of `ijd`; `zip` stops at
# the shorter of the two.
# NOTE(review): the k-th `ijd` entry is paired with the k-th linear index of Δ, not
# with the `[i, j]` position that entry writes to in the forward pass, and the
# `max_num_nbr` acceptance test is not replayed here — confirm this is intended.
for (ix, (_,_,d)) in zip(eachindex(Δ), ijd)
# Differentiate `f` at `d`; the primal value `y_` is not used afterwards.
y_, back_ = Zygote.pullback(f, d)
# Scale the cotangent in place by the first result of f's pullback, seeded with
# `Zygote.sensitivity(d)` (presumably the unit seed for `d` — TODO confirm).
Δ[ix] *= first(back_(Zygote.sensitivity(d)))
end
# One entry per positional argument, in order:
# weight_mat, f, ijd, nb_counts, longest_dists.
# NOTE(review): the `weight_mat` cotangent is the already-scaled Δ, and the `ijd`
# cotangent holds `length(Δ)` tuples `(nothing, nothing, Δ[k])` rather than
# `length(ijd)` — verify against the callers / gradient tests.
(reshape(Δ, s), nothing,
collect(zip(fill(nothing, size(Δ,1)),
fill(nothing, size(Δ,1)),
Δ)),
nothing,
nothing)
end

(y,ld), cutoff_pb
end

# Reverse rule for constructing an `SArray{D, Ts, ND, L}` from `x...`.
# The primal is the constructed array. The pullback rebuilds every element `t` of the
# incoming cotangent as `eltype(x...)(t...)` and returns the results as a tuple, next to
# `NoTangent()` for the type argument.
# NOTE(review): `eltype(x...)` and the single tuple tangent look like they assume `x`
# holds exactly one argument — confirm against the call sites.
function Zygote.ChainRulesCore.rrule(::Type{T}, x...) where T <: SArray{D, Ts, ND, L} where {D, Ts, ND, L}
    function sarray_pb(Δy)
        converted = map(t -> eltype(x...)(t...), Δy)
        return NoTangent(), (converted...,)
    end
    return SArray{D, Ts, ND, L}(x...), sarray_pb
end

# Zygote adjoint for `Unitful.ustrip` on a `Quantity`: the forward value is
# `Unitful.ustrip(x)`, and the pullback wraps the incoming cotangent in the same
# `Quantity{T, D, U}` type as the input.
Zygote.@adjoint function Unitful.ustrip(x::Quantity{T, D, U}) where {T, D, U}
    ustrip_pb = Δ -> (Quantity{T, D, U}(Δ),)
    Unitful.ustrip(x), ustrip_pb
end
76 changes: 49 additions & 27 deletions src/graph_building.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ include("utils.jl")
"""
    inverse_square(x)

Return `x` raised to the power `-2.0`.
"""
function inverse_square(x)
    return x^-2.0
end

"""
    exp_decay(x)

Return `exp(-x)`.
"""
function exp_decay(x)
    return exp(-x)
end

# Handle to the Python module `pymatgen.core.structure`.
# A Python object created at module top level is captured when the package is
# precompiled and is no longer valid when the package is loaded in a later session.
# The constant therefore starts out as a NULL placeholder and is filled in at load time
# by `__init__`, following the pattern the PythonCall documentation recommends for
# packages.
const pymat_structure = PythonCall.pynew()

# NOTE(review): assumes this module does not already define `__init__` elsewhere and
# that the `PythonCall` module name is in scope (e.g. via `using PythonCall`); if an
# `__init__` exists, move the `pycopy!` call into it instead of adding this method.
function __init__()
    PythonCall.pycopy!(pymat_structure, pyimport("pymatgen.core.structure"))
    return nothing
end

"""
Build graph from a file storing a crystal structure (will be read in using AtomsIO, which in turn calls ASE). Returns the weight matrix and elements used for constructing an `AtomGraph`.

Expand Down Expand Up @@ -38,8 +40,7 @@ function build_graph(

if use_voronoi
@info "Note that building neighbor lists and edge weights via the Voronoi method requires the assumption of periodic boundaries. If you are building a graph for a molecule, you probably do not want this..."
s = pyimport("pymatgen.core.structure")
struc = s.Structure.from_file(file_path)
struc = pymat_structure..Structure.from_file(file_path)
weight_mat = weights_voronoi(struc)
return weight_mat, atom_ids, struc
else
Expand Down Expand Up @@ -83,7 +84,7 @@ function build_graph(
max_num_nbr = max_num_nbr,
dist_decay_func = dist_decay_func,
)
return weight_mat, String.(atomic_symbol(sys)), sys
return weight_mat # , string.(atomic_symbol(sys)), sys
end

"""
Expand All @@ -103,21 +104,38 @@ function weights_cutoff(is, js, dists; max_num_nbr = 12, dist_decay_func = inver

# iterate over list of tuples to build edge weights...
# note that neighbor list double counts so we only have to increment one counter per pair
weight_mat = zeros(Float32, num_atoms, num_atoms)
weight_mat = zeros(Float64, round(Int,num_atoms), round(Int,num_atoms))
weight_mat, longest_dists = _cutoff!(weight_mat,
dist_decay_func,
ijd,
nb_counts,
longest_dists)

# average across diagonal, just in case
weight_mat = 0.5 .* (weight_mat .+ weight_mat')

# normalize weights
weight_mat = weight_mat ./ maximum(weight_mat)
weight_mat
end

"""
    _cutoff!(weight_mat, f, ijd, nb_counts, longest_dists; max_num_nbr = 12)

Accumulate edge weights into `weight_mat` in place.

For every `(i, j, d)` in `ijd`, `f(d)` is added to `weight_mat[i, j]` when either
`nb_counts[i]` is still below `max_num_nbr`, or `d` is approximately equal to
`longest_dists[i]` (the distance most recently accepted for `i`). On acceptance,
`longest_dists[i]` is set to `d` and `nb_counts[i]` is incremented; both arrays are
mutated.

Returns the tuple `(weight_mat, longest_dists)`. No symmetrisation or normalisation is
performed here.
"""
function _cutoff!(weight_mat, f, ijd, nb_counts, longest_dists; max_num_nbr = 12)
    for (i, j, d) in ijd
        # FiniteDifferences doesn't like non integers as indices
        # and is used to test
        i, j = round.(Int, (i, j))

        # if we're under the max OR if it's at the same distance as the previous one
        if nb_counts[i] < max_num_nbr || isapprox(longest_dists[i], d)
            weight_mat[i, j] += f(d)
            longest_dists[i] = d
            nb_counts[i] += 1
        end
    end

    return weight_mat, longest_dists
end

"""
Expand Down Expand Up @@ -183,29 +201,33 @@ function neighbor_list(sys; cutoff_radius::Real = 8.0)
cutoff_radius = 0.99 * min_celldim
end

get_pairdist((i,j)) = sqrt(sum((sc_pos[:, i] .- sc_pos[:, j]).^2))
index_map(i) = (i - 1) % n_atoms + 1 # I suddenly understand why some people dislike 1-based indexing

# todo: try BallTree, also perhaps other leafsize values
# also, the whole supercell thing could probably be avoided (and this function sped up substantially) by doing this using something like:
# ptree = BruteTree(hcat(ustrip.(position(s))...), PeriodicEuclidean([1,1,1]))
# but I don't have time to carefully test that right now and I know the supercell thing should work
tree = BruteTree(sc_pos)
is_raw = 13*n_atoms+1:14*n_atoms
js_raw = inrange(tree, sc_pos[:, is_raw], cutoff_radius)

index_map(i) = (i - 1) % n_atoms + 1 # I suddenly understand why some people dislike 1-based indexing

# this looks horrifying but it does do the right thing...
#ijraw_pairs = [p for p in Iterators.flatten([Iterators.product([p for p in zip(is_raw, js_raw)][n]...) for n in 1:4]) if p[1]!=p[2]]
split1 = map(zip(is_raw, js_raw)) do x
return [
p for p in [(x[1], [j for j in js if j != x[1]]...) for js in x[2]] if
length(p) == 2
]
is, js, ijraw_pairs = Zygote.ignore() do
tree = BruteTree(sc_pos)
is_raw = 13*n_atoms+1:14*n_atoms
js_raw = inrange(tree, sc_pos[:, is_raw], cutoff_radius)


# this looks horrifying but it does do the right thing...
#ijraw_pairs = [p for p in Iterators.flatten([Iterators.product([p for p in zip(is_raw, js_raw)][n]...) for n in 1:4]) if p[1]!=p[2]]
split1 = map(zip(is_raw, js_raw)) do x
return [
p for p in [(x[1], [j for j in js if j != x[1]]...) for js in x[2]] if
length(p) == 2
]
end
ijraw_pairs = [(split1...)...]
is = index_map.([t[1] for t in ijraw_pairs])
js = index_map.([t[2] for t in ijraw_pairs])
is, js, ijraw_pairs
end
ijraw_pairs = [(split1...)...]
get_pairdist((i,j)) = sqrt(sum((sc_pos[:, i] .- sc_pos[:, j]).^2))
dists = get_pairdist.(ijraw_pairs)
is = index_map.([t[1] for t in ijraw_pairs])
js = index_map.([t[2] for t in ijraw_pairs])
return is, js, dists
end

Expand Down
28 changes: 14 additions & 14 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@ using AtomsBase
"""
function build_supercell(sys::AbstractSystem, repfactors)
    @assert length(repfactors) == n_dimensions(sys) "Your list of replication factors doesn't match the dimensionality of the system!"

    old_box = bounding_box(sys)
    new_box = repfactors .* old_box

    # The element symbols are computed inside `Zygote.ignore` so they are excluded
    # from differentiation.
    symbols = Zygote.ignore() do
        repeat(atomic_symbol(sys), prod(repfactors))
    end

    # One integer offset tuple per replicated cell; each is turned into a translation
    # by summing the correspondingly scaled entries of the original box.
    integer_offsets = Iterators.product(range.(0, repfactors .- 1, step=1)...)
    position_offsets = [sum(offset .* old_box) for offset in integer_offsets]

    old_positions = position(sys)

    # Build the shifted copies with nested `map`s instead of writing into a
    # preallocated array, so no array is mutated on the differentiated path.
    # The result lists all atoms for the first offset, then all atoms for the next, etc.
    shifted = map(position_offsets) do offset
        map(pos -> pos .+ offset, old_positions)
    end
    new_positions = reduce(vcat, vec(shifted))

    return new_box, symbols, new_positions
end
Loading