Skip to content

Commit

Permalink
feat: AD through graph building
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed Nov 17, 2023
1 parent 5422754 commit 59c6e69
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 65 deletions.
17 changes: 17 additions & 0 deletions src/adjoints.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Zygote
using ZygoteRules: @adjoint

@adjoint function Dict(g::Base.Generator)
Expand Down Expand Up @@ -31,3 +32,19 @@ end

(y,ld), cutoff_pb
end

function Zygote.ChainRulesCore.rrule(::Type{T}, x...) where T <: SArray{D, Ts, ND, L} where {D, Ts, ND, L}
y = SArray{D, Ts, ND, L}(x...)
function sarray_pb(Δy)
Δy = map(t->eltype(x...)(t...), Δy)
return NoTangent(), (Δy...,)
end
return y, sarray_pb
end

Zygote.@adjoint function Unitful.ustrip(x::Quantity{T, D, U}) where {T, D, U}
function back(Δ)
(Quantity{T, D, U}(Δ), )
end
Unitful.ustrip(x), back
end
72 changes: 21 additions & 51 deletions src/graph_building.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,40 +87,6 @@ function build_graph(
return weight_mat # , string.(atomic_symbol(sys)), sys
end

# """
# Build graph using neighbor number cutoff method adapted from original CGCNN.
#
# !!! note
# `max_num_nbr` is a "soft" max, in that if there are more of the same distance as the last, all of those will be added.
# """
# function weights_cutoff(is, js, dists; max_num_nbr = 12, dist_decay_func = inverse_square)
# # sort by distance
# ijd = sort([t for t in zip(is, js, dists)], by = t -> t[3])
#
# # initialize neighbor counts
# num_atoms = maximum(is)
# local nb_counts = Dict(i => 0 for i = 1:num_atoms)
# local longest_dists = Dict(i => 0.0 for i = 1:num_atoms)
#
# # iterate over list of tuples to build edge weights...
# # note that neighbor list double counts so we only have to increment one counter per pair
# weight_mat = zeros(Float32, num_atoms, num_atoms)
# for (i, j, d) in ijd
# # if we're under the max OR if it's at the same distance as the previous one
# if nb_counts[i] < max_num_nbr || isapprox(longest_dists[i], d)
# weight_mat[i, j] += dist_decay_func(d)
# longest_dists[i] = d
# nb_counts[i] += 1
# end
# end
#
# # average across diagonal, just in case
# weight_mat = 0.5 .* (weight_mat .+ weight_mat')
#
# # normalize weights
# weight_mat = weight_mat ./ maximum(weight_mat)
# end

"""
Build graph using neighbor number cutoff method adapted from original CGCNN.
Expand Down Expand Up @@ -235,29 +201,33 @@ function neighbor_list(sys; cutoff_radius::Real = 8.0)
cutoff_radius = 0.99 * min_celldim
end

get_pairdist((i,j)) = sqrt(sum((sc_pos[:, i] .- sc_pos[:, j]).^2))
index_map(i) = (i - 1) % n_atoms + 1 # I suddenly understand why some people dislike 1-based indexing

# todo: try BallTree, also perhaps other leafsize values
# also, the whole supercell thing could probably be avoided (and this function sped up substantially) by doing this using something like:
# ptree = BruteTree(hcat(ustrip.(position(s))...), PeriodicEuclidean([1,1,1]))
# but I don't have time to carefully test that right now and I know the supercell thing should work
tree = BruteTree(sc_pos)
is_raw = 13*n_atoms+1:14*n_atoms
js_raw = inrange(tree, sc_pos[:, is_raw], cutoff_radius)

index_map(i) = (i - 1) % n_atoms + 1 # I suddenly understand why some people dislike 1-based indexing

# this looks horrifying but it does do the right thing...
#ijraw_pairs = [p for p in Iterators.flatten([Iterators.product([p for p in zip(is_raw, js_raw)][n]...) for n in 1:4]) if p[1]!=p[2]]
split1 = map(zip(is_raw, js_raw)) do x
return [
p for p in [(x[1], [j for j in js if j != x[1]]...) for js in x[2]] if
length(p) == 2
]
is, js, ijraw_pairs = Zygote.ignore() do
tree = BruteTree(sc_pos)
is_raw = 13*n_atoms+1:14*n_atoms
js_raw = inrange(tree, sc_pos[:, is_raw], cutoff_radius)


# this looks horrifying but it does do the right thing...
#ijraw_pairs = [p for p in Iterators.flatten([Iterators.product([p for p in zip(is_raw, js_raw)][n]...) for n in 1:4]) if p[1]!=p[2]]
split1 = map(zip(is_raw, js_raw)) do x
return [
p for p in [(x[1], [j for j in js if j != x[1]]...) for js in x[2]] if
length(p) == 2
]
end
ijraw_pairs = [(split1...)...]
is = index_map.([t[1] for t in ijraw_pairs])
js = index_map.([t[2] for t in ijraw_pairs])
is, js, ijraw_pairs
end
ijraw_pairs = [(split1...)...]
get_pairdist((i,j)) = sqrt(sum((sc_pos[:, i] .- sc_pos[:, j]).^2))
dists = get_pairdist.(ijraw_pairs)
is = index_map.([t[1] for t in ijraw_pairs])
js = index_map.([t[2] for t in ijraw_pairs])
return is, js, dists
end

Expand Down
28 changes: 14 additions & 14 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@ using AtomsBase
"""
function build_supercell(sys::AbstractSystem, repfactors)
@assert length(repfactors) == n_dimensions(sys) "Your list of replication factors doesn't match the dimensionality of the system!"

# @show repfactors
old_box = bounding_box(sys)
new_box = repfactors .* old_box
symbols = repeat(atomic_symbol(sys), prod(repfactors))
symbols = Zygote.ignore() do
repeat(atomic_symbol(sys), prod(repfactors))
end

integer_offsets = Iterators.product(range.(Ref(0), repfactors .- 1, step=1)...)
integer_offsets = Iterators.product(range.(0, repfactors .- 1, step=1)...)
position_offsets = [sum(offset .* old_box) for offset in integer_offsets]

old_positions = position(sys)
new_positions = repeat(MArray.(position(sys)), prod(repfactors))

for (i, offset) in enumerate(position_offsets)
indices = (1:length(sys)) .+ (i-1) * length(sys)
for (j, pos) in enumerate(old_positions)
new_positions[indices[j]] = pos .+ offset
end
end

return new_box, symbols, SVector.(new_positions)
end
a = map(enumerate(position_offsets)) do (i,offset)
map(enumerate(old_positions)) do (j, pos)
pos .+ offset
end
end
new_positions = reduce(vcat, vec(a))

return new_box, symbols, new_positions
end

0 comments on commit 59c6e69

Please sign in to comment.