diff --git a/Project.toml b/Project.toml
index 34aadfc..18f25dc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,6 +7,7 @@ version = "0.1.3"
 AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
 AtomsIO = "1692102d-eeb4-4df9-807b-c9517f998d44"
 Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ChemistryFeaturization = "6c925690-434a-421d-aea7-51398c5b007a"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
@@ -22,6 +23,8 @@ SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
 Xtals = "ede5f01d-793e-4c47-9885-c447d1f18d6d"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 
 [compat]
 AtomsBase = "0.2, 0.3"
@@ -38,7 +41,10 @@ NearestNeighbors = "0.4"
 PythonCall = "0.9.10"
 SimpleWeightedGraphs = "1.2"
 StaticArrays = "1.5"
-Unitful = "1.12"
+Unitful = "1.12"
 Xtals = "0.3, 0.4"
+ChainRulesCore = "1"
+Zygote = "0.6"
+ZygoteRules = "0.2"
 julia = "1.6, 1.7"
 
diff --git a/src/AtomGraphs.jl b/src/AtomGraphs.jl
index bd958c4..410defe 100644
--- a/src/AtomGraphs.jl
+++ b/src/AtomGraphs.jl
@@ -10,6 +10,9 @@ using MolecularGraph
 
 import AtomsBase: atomic_symbol
 import ChemistryFeaturization: elements
+using ChainRulesCore
+using Zygote
+using ZygoteRules
 
 include("atomgraph.jl")
 export AtomGraph, elements, visualize
@@ -17,4 +20,6 @@ export AtomGraph, elements, visualize
 include("graph_building.jl")
 export inverse_square, exp_decay
 
+include("adjoints.jl")
+
 end
diff --git a/src/adjoints.jl b/src/adjoints.jl
new file mode 100644
index 0000000..7fa6e00
--- /dev/null
+++ b/src/adjoints.jl
@@ -0,0 +1,67 @@
+using Zygote
+using ZygoteRules: @adjoint
+
+# NOTE(review): `Dict`, `SArray` and `Unitful.ustrip` are not owned by this package, so
+# the rules below are type piracy and affect every package loaded with AtomGraphs;
+# consider upstreaming them.
+
+# Adjoint for building a `Dict` from a generator: `g.f` is pulled back once per element
+# of `g.iter`.
+# NOTE(review): `backs` (generator order) is zipped positionally with `pairs(Δ)`
+# (iteration order of `Δ`); confirm the two orders agree for the cotangents in use.
+@adjoint function Dict(g::Base.Generator)
+    ys, backs = Zygote.unzip([Zygote.pullback(g.f, args) for args in g.iter])
+    Dict(ys...), Δ -> begin
+        dd = Dict(k => b(Δ)[1].second for (b, (k, v)) in zip(backs, pairs(Δ)))
+        ((x for x in dd),)
+    end
+end
+
+# Adjoint for the in-place cutoff loop `_cutoff!`. The primal is called as-is, so it
+# mutates `weight_mat`, `nb_counts` and `longest_dists`. The pullback returns one entry
+# per positional argument: (weight_mat, f, ijd, nb_counts, longest_dists).
+# NOTE(review): the loop scales the leading entries of the flattened cotangent by f'(d)
+# instead of indexing it at each pair's (i, j); confirm this is the intended gradient.
+@adjoint function _cutoff!(weight_mat, f, ijd,
+                           nb_counts, longest_dists;
+                           max_num_nbr = 12)
+    y, ld = _cutoff!(weight_mat, f, ijd,
+                     nb_counts, longest_dists;
+                     max_num_nbr = max_num_nbr)
+    function cutoff_pb((Δ, nt))
+        s = size(Δ)
+        Δ = vec(collect(Δ))
+        for (ix, (_, _, d)) in zip(eachindex(Δ), ijd)
+            _, back_ = Zygote.pullback(f, d)
+            Δ[ix] *= first(back_(Zygote.sensitivity(d)))
+        end
+        (reshape(Δ, s), nothing,
+         collect(zip(fill(nothing, size(Δ, 1)),
+                     fill(nothing, size(Δ, 1)),
+                     Δ)),
+         nothing,
+         nothing)
+    end
+
+    (y, ld), cutoff_pb
+end
+
+# Reverse rule for constructing an `SArray` from its arguments.
+# NOTE(review): `eltype(x...)` only has single-argument methods, so this presumably
+# expects the elements to arrive as one collection; confirm against the call sites.
+function Zygote.ChainRulesCore.rrule(::Type{T}, x...) where T <: SArray{D, Ts, ND, L} where {D, Ts, ND, L}
+    y = SArray{D, Ts, ND, L}(x...)
+    function sarray_pb(Δy)
+        Δy = map(t->eltype(x...)(t...), Δy)
+        return NoTangent(), (Δy...,)
+    end
+    return y, sarray_pb
+end
+
+# Adjoint for `Unitful.ustrip`: the cotangent is re-wrapped in the input's `Quantity` type.
+Zygote.@adjoint function Unitful.ustrip(x::Quantity{T, D, U}) where {T, D, U}
+    function back(Δ)
+        (Quantity{T, D, U}(Δ), )
+    end
+    Unitful.ustrip(x), back
+end
diff --git a/src/graph_building.jl b/src/graph_building.jl
index 0fe0c9b..94b05f7 100644
--- a/src/graph_building.jl
+++ b/src/graph_building.jl
@@ -10,6 +10,10 @@ include("utils.jl")
 inverse_square(x) = x^-2.0
 exp_decay(x) = exp(-x)
 
+# Imported on demand rather than stored in a top-level `const`: a `Py` captured at
+# precompile time is not valid in the session that later loads the package.
+pymat_structure() = pyimport("pymatgen.core.structure")
+
 """
 Build graph from a file storing a crystal structure (will be read in using AtomsIO, which in turn calls ASE). Returns the weight matrix and elements used for constructing an `AtomGraph`.
 
@@ -38,8 +42,7 @@ function build_graph(
 
     if use_voronoi
         @info "Note that building neighbor lists and edge weights via the Voronoi method requires the assumption of periodic boundaries. If you are building a graph for a molecule, you probably do not want this..."
-        s = pyimport("pymatgen.core.structure")
-        struc = s.Structure.from_file(file_path)
+        struc = pymat_structure().Structure.from_file(file_path)
         weight_mat = weights_voronoi(struc)
         return weight_mat, atom_ids, struc
     else
@@ -83,7 +86,10 @@ function build_graph(
         max_num_nbr = max_num_nbr,
         dist_decay_func = dist_decay_func,
     )
-    return weight_mat, String.(atomic_symbol(sys)), sys
+    symbols = Zygote.ignore() do
+        String.(atomic_symbol(sys))
+    end
+    return weight_mat, symbols, sys
 end
 
 """
@@ -103,21 +109,48 @@ function weights_cutoff(is, js, dists; max_num_nbr = 12, dist_decay_func = inver
 
     # iterate over list of tuples to build edge weights...
     # note that neighbor list double counts so we only have to increment one counter per pair
-    weight_mat = zeros(Float32, num_atoms, num_atoms)
+    weight_mat = zeros(Float64, round(Int, num_atoms), round(Int, num_atoms))
+    weight_mat, longest_dists = _cutoff!(weight_mat,
+                                         dist_decay_func,
+                                         ijd,
+                                         nb_counts,
+                                         longest_dists;
+                                         max_num_nbr = max_num_nbr)
+
+    # average across diagonal, just in case
+    weight_mat = 0.5 .* (weight_mat .+ weight_mat')
+
+    # normalize weights
+    weight_mat = weight_mat ./ maximum(weight_mat)
+    return weight_mat
+end
+
+"""
+    _cutoff!(weight_mat, f, ijd, nb_counts, longest_dists; max_num_nbr = 12)
+
+Accumulate `f(d)` into `weight_mat[i, j]` for each `(i, j, d)` in `ijd`. A pair is
+accepted when `nb_counts[i] < max_num_nbr` or `d ≈ longest_dists[i]`; on acceptance
+`longest_dists[i]` is set to `d` and `nb_counts[i]` is incremented. `weight_mat`,
+`nb_counts` and `longest_dists` are mutated in place.
+
+Returns `(weight_mat, longest_dists)`.
+"""
+function _cutoff!(weight_mat, f, ijd,
+                  nb_counts, longest_dists; max_num_nbr = 12)
     for (i, j, d) in ijd
+        # FiniteDifferences doesn't like non integers as indices
+        # and is used to test
+        i, j = round.(Int, (i, j))
+
         # if we're under the max OR if it's at the same distance as the previous one
         if nb_counts[i] < max_num_nbr || isapprox(longest_dists[i], d)
-            weight_mat[i, j] += dist_decay_func(d)
+            weight_mat[i, j] += f(d)
             longest_dists[i] = d
             nb_counts[i] += 1
         end
     end
 
-    # average across diagonal, just in case
-    weight_mat = 0.5 .* (weight_mat .+ weight_mat')
-
-    # normalize weights
-    weight_mat = weight_mat ./ maximum(weight_mat)
+    return weight_mat, longest_dists
 end
 
 """
@@ -183,29 +216,34 @@ function neighbor_list(sys; cutoff_radius::Real = 8.0)
         cutoff_radius = 0.99 * min_celldim
     end
 
+    get_pairdist((i, j)) = sqrt(sum((sc_pos[:, i] .- sc_pos[:, j]).^2))
+    index_map(i) = (i - 1) % n_atoms + 1 # I suddenly understand why some people dislike 1-based indexing
+
     # todo: try BallTree, also perhaps other leafsize values
     # also, the whole supercell thing could probably be avoided (and this function sped up substantially) by doing this using something like:
     # ptree = BruteTree(hcat(ustrip.(position(s))...), PeriodicEuclidean([1,1,1]))
     # but I don't have time to carefully test that right now and I know the supercell thing should work
-    tree = BruteTree(sc_pos)
-    is_raw = 13*n_atoms+1:14*n_atoms
-    js_raw = inrange(tree, sc_pos[:, is_raw], cutoff_radius)
-
-    index_map(i) = (i - 1) % n_atoms + 1 # I suddenly understand why some people dislike 1-based indexing
-
-    # this looks horrifying but it does do the right thing...
-    #ijraw_pairs = [p for p in Iterators.flatten([Iterators.product([p for p in zip(is_raw, js_raw)][n]...) for n in 1:4]) if p[1]!=p[2]]
-    split1 = map(zip(is_raw, js_raw)) do x
-        return [
-            p for p in [(x[1], [j for j in js if j != x[1]]...) for js in x[2]] if
-            length(p) == 2
-        ]
+    # The neighbor search only produces indices, so it is hidden from Zygote;
+    # gradients flow through `get_pairdist` below.
+    is, js, ijraw_pairs = Zygote.ignore() do
+        tree = BruteTree(sc_pos)
+        is_raw = 13*n_atoms+1:14*n_atoms
+        js_raw = inrange(tree, sc_pos[:, is_raw], cutoff_radius)
+
+        # this looks horrifying but it does do the right thing...
+        #ijraw_pairs = [p for p in Iterators.flatten([Iterators.product([p for p in zip(is_raw, js_raw)][n]...) for n in 1:4]) if p[1]!=p[2]]
+        split1 = map(zip(is_raw, js_raw)) do x
+            return [
+                p for p in [(x[1], [j for j in js if j != x[1]]...) for js in x[2]] if
+                length(p) == 2
+            ]
+        end
+        ijraw_pairs = [(split1...)...]
+        is = index_map.([t[1] for t in ijraw_pairs])
+        js = index_map.([t[2] for t in ijraw_pairs])
+        is, js, ijraw_pairs
     end
-    ijraw_pairs = [(split1...)...]
-    get_pairdist((i,j)) = sqrt(sum((sc_pos[:, i] .- sc_pos[:, j]).^2))
     dists = get_pairdist.(ijraw_pairs)
-    is = index_map.([t[1] for t in ijraw_pairs])
-    js = index_map.([t[2] for t in ijraw_pairs])
     return is, js, dists
 end
 
diff --git a/src/utils.jl b/src/utils.jl
index af132ad..1e113c4 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -12,23 +12,24 @@ using AtomsBase
 """
 function build_supercell(sys::AbstractSystem, repfactors)
     @assert length(repfactors) == n_dimensions(sys) "Your list of replication factors doesn't match the dimensionality of the system!"
 
     old_box = bounding_box(sys)
     new_box = repfactors .* old_box
-    symbols = repeat(atomic_symbol(sys), prod(repfactors))
+    # Element symbols carry no gradient, so keep them out of the Zygote trace.
+    symbols = Zygote.ignore() do
+        repeat(atomic_symbol(sys), prod(repfactors))
+    end
 
-    integer_offsets = Iterators.product(range.(Ref(0), repfactors .- 1, step=1)...)
+    integer_offsets = Iterators.product(range.(0, repfactors .- 1, step=1)...)
     position_offsets = [sum(offset .* old_box) for offset in integer_offsets]
-
     old_positions = position(sys)
-    new_positions = repeat(MArray.(position(sys)), prod(repfactors))
-
-    for (i, offset) in enumerate(position_offsets)
-        indices = (1:length(sys)) .+ (i-1) * length(sys)
-        for (j, pos) in enumerate(old_positions)
-            new_positions[indices[j]] = pos .+ offset
-        end
-    end
 
-    return new_box, symbols, SVector.(new_positions)
-end
\ No newline at end of file
+    # Build the shifted copies without mutation (Zygote cannot differentiate through
+    # array mutation): one vector of positions per offset, concatenated in offset order.
+    shifted = map(position_offsets) do offset
+        map(pos -> pos .+ offset, old_positions)
+    end
+    new_positions = reduce(vcat, vec(shifted))
+
+    return new_box, symbols, new_positions
+end