bugfix: e3gnn parallel
YutackPark committed Sep 16, 2024
1 parent 6d00fc7 commit c5905b7
Showing 3 changed files with 110 additions and 120 deletions.
5 changes: 4 additions & 1 deletion sevenn/pair_e3gnn/comm_brick.cpp
@@ -1068,7 +1068,10 @@ void CommBrick::forward_comm(PairE3GNNParallel *pair)
buf_send_ = reinterpret_cast<float*>(buf_send);
buf_recv_ = reinterpret_cast<float*>(buf_recv);
}
if (nswap > 6) error->all(FLERR,"PairE3GNNParallel: Cell size is too small. Please use a single GPU or replicate the cell.");
if(!comm_preprocess_done) {
pair->notify_proc_ids(sendproc, recvproc);
}
if (nswap > 6) error->all(FLERR,"PairE3GNNParallel: Cell size is too small. Please use a single GPU or make a supercell");

for (iswap = 0; iswap < nswap; iswap++) {
if(sendproc[iswap] == me) continue;
219 changes: 102 additions & 117 deletions sevenn/pair_e3gnn/pair_e3gnn_parallel.cpp
Expand Up @@ -21,6 +21,7 @@
#include <c10/core/Scalar.h>
#include <c10/core/TensorOptions.h>
#include <cstdlib>
#include <filesystem>
#include <numeric>
#include <string>

@@ -121,6 +122,10 @@ PairE3GNNParallel::PairE3GNNParallel(LAMMPS *lmp) : Pair(lmp) {
device_name = "CPU";
}
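// If the OFF_E3GNN_PARALLEL_CUDA_MPI environment variable is set, CUDA-aware
// MPI is turned off and communication tensors are staged through the CPU
// instead (device_comm falls back to torch::kCPU below).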

if (std::getenv("OFF_E3GNN_PARALLEL_CUDA_MPI")) {
use_cuda_mpi = false;
}

if (lmp->screen) {
if (use_gpu && !use_cuda_mpi) {
device_comm = torch::kCPU;
@@ -159,26 +164,6 @@ torch::Device PairE3GNNParallel::get_cuda_device() {
if (print_info)
std::cout << world_rank << " Available # of GPUs found: " << num_gpus
<< std::endl;
if (cuda_visible == nullptr) {
// assume every gpu in node is avail
// believe user did right thing...
// idx = rank % num_gpus;
// std::cout << world_rank << " use GPU index(No CUDA_VISIBLE_DEVICES set):
// " << idx << std::endl;
} else {
/*
auto delim = ",";
char *tok = std::strtok(cuda_visible, delim);
std::vector<std::string> device_ids;
while(tok != nullptr) {
device_ids.push_back(std::string(tok));
tok = std::strtok(nullptr, delim);
}
idx = std::stoi(device_ids[rank % device_ids.size()]);
std::cout << world_rank << " use GPU index(from CUDA_VISIBLE_DEVICES): " <<
idx << std::endl;
*/
}
cudaError_t cuda_err = cudaSetDevice(idx);
if (cuda_err != cudaSuccess) {
std::cerr << "E3GNN: Failed to set CUDA device: "
@@ -203,23 +188,6 @@ bool PairE3GNNParallel::is_comm_preprocess_done() {
return comm_preprocess_done;
}

void PairE3GNNParallel::warning_pressure() {
static bool already_did = false;
if (!already_did && comm->me == 0) {
if (lmp->screen)
fprintf(
lmp->screen,
"WARNING: PairE3GNNParallel does not support pressure calculation. "
"Pressure on log is wrong. Use serial version if you needed\n");
if (lmp->logfile)
fprintf(
lmp->logfile,
"WARNING: PairE3GNNParallel does not support pressure calculation. "
"Pressure on log is wrong. Use serial version if you needed\n");
already_did = true;
}
}

void PairE3GNNParallel::compute(int eflag, int vflag) {
/*
Graph build on cpu
@@ -229,10 +197,7 @@ void PairE3GNNParallel::compute(int eflag, int vflag) {
else
evflag = vflag_fdotr = 0;
if (vflag_atom) {
error->all(FLERR, "atomic stress related feature is not supported\n");
}
if (vflag) {
warning_pressure();
error->all(FLERR, "atomic stress is not supported\n");
}

if (atom->tag_consecutive() == 0) {
@@ -332,7 +297,7 @@ void PairE3GNNParallel::compute(int eflag, int vflag) {
} // j loop end
} // i loop end

// memeber variable
// member variable
graph_size = graph_indexer;
const int ghost_node_num = graph_size - nlocal;

@@ -431,7 +396,7 @@ void PairE3GNNParallel::compute(int eflag, int vflag) {
std::vector<torch::Tensor> grads;
std::vector<torch::Tensor> of_tensor;

// TODO: most values of self_conn_grads were zero becuase we use only scalars
// TODO: most values of self_conn_grads were zero because we use only scalars
// for energy
for (auto rit = wrt_tensors.rbegin(); rit != wrt_tensors.rend(); ++rit) {
// edge_vec, x, x_ghost order
@@ -493,20 +458,48 @@ void PairE3GNNParallel::compute(int eflag, int vflag) {
dE_dr = dE_dr.to(torch::kCPU);
torch::Tensor force_tensor = torch::zeros({graph_indexer, 3});

force_tensor.scatter_reduce_(
0, edge_idx_src_tensor.repeat_interleave(3).view({nedges, 3}), dE_dr,
"sum");
force_tensor.scatter_reduce_(
0, edge_idx_dst_tensor.repeat_interleave(3).view({nedges, 3}),
torch::neg(dE_dr), "sum");
auto _edge_idx_src_tensor =
edge_idx_src_tensor.repeat_interleave(3).view({nedges, 3});
auto _edge_idx_dst_tensor =
edge_idx_dst_tensor.repeat_interleave(3).view({nedges, 3});

force_tensor.scatter_reduce_(0, _edge_idx_src_tensor, dE_dr, "sum");
force_tensor.scatter_reduce_(0, _edge_idx_dst_tensor, torch::neg(dE_dr),
"sum");

auto forces = force_tensor.accessor<float, 2>();

for (int graph_idx = 0; graph_idx < graph_indexer; graph_idx++) {
int i = graph_index_to_i[graph_idx];
f[i][0] = forces[graph_idx][0];
f[i][1] = forces[graph_idx][1];
f[i][2] = forces[graph_idx][2];
f[i][0] += forces[graph_idx][0];
f[i][1] += forces[graph_idx][1];
f[i][2] += forces[graph_idx][2];
}
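// The vflag block below assembles the virial from per-edge products of the
// edge vector with dE_dr: three diagonal terms plus the xy, yz, and zx cross
// terms, scattered onto destination atoms in Voigt order (xx, yy, zz, xy, yz,
// zx) and summed with a sign flip. LAMMPS' virial[] uses the order
// (xx, yy, zz, xy, xz, yz), which is why elements 4 and 5 are swapped when the
// totals are accumulated into virial[].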

if (vflag) {
auto diag = inp_edge_vec * dE_dr;
auto s12 = inp_edge_vec.select(1, 0) * dE_dr.select(1, 1);
auto s23 = inp_edge_vec.select(1, 1) * dE_dr.select(1, 2);
auto s31 = inp_edge_vec.select(1, 2) * dE_dr.select(1, 0);
std::vector<torch::Tensor> voigt_list = {
diag, s12.unsqueeze(-1), s23.unsqueeze(-1), s31.unsqueeze(-1)};
auto voigt = torch::cat(voigt_list, 1);

torch::Tensor per_atom_stress_tensor = torch::zeros({graph_indexer, 6});
auto _edge_idx_dst6_tensor =
edge_idx_dst_tensor.repeat_interleave(6).view({nedges, 6});
per_atom_stress_tensor.scatter_reduce_(0, _edge_idx_dst6_tensor, voigt,
"sum");
auto virial_stress_tensor =
torch::neg(torch::sum(per_atom_stress_tensor, 0));
auto virial_stress = virial_stress_tensor.accessor<float, 1>();

virial[0] += virial_stress[0];
virial[1] += virial_stress[1];
virial[2] += virial_stress[2];
virial[3] += virial_stress[3];
virial[4] += virial_stress[5];
virial[5] += virial_stress[4];
}

if (eflag_atom) {
@@ -526,13 +519,6 @@ void PairE3GNNParallel::compute(int eflag, int vflag) {
comm_index_pack_forward[i].clear();
comm_index_unpack_forward[i].clear();
comm_index_unpack_reverse[i].clear();
/*
if(use_cuda_mpi) {
comm_index_pack_forward_tensor[i].clear();
comm_index_unpack_forward_tensor[i].clear();
comm_index_unpack_reverse_tensor[i].clear();
}
*/
}

extra_graph_idx_map.clear();
@@ -563,7 +549,7 @@ void PairE3GNNParallel::coeff(int narg, char **arg) {

if (strcmp(arg[0], "*") != 0 || strcmp(arg[1], "*") != 0) {
error->all(FLERR,
"e3gnn: firt and second input of pair_coeff should be '*'");
"e3gnn: first and second input of pair_coeff should be '*'");
}
// expected input : pair_coeff * * pot.pth type_name1 type_name2 ...

@@ -579,13 +565,30 @@

// model loading from input
int n_model = std::stoi(arg[2]);
try {
for (int i = 3; i < n_model + 3; i++) {
model_list.push_back(
torch::jit::load(std::string(arg[i]), device, meta_dict));
int chem_arg_i = 4;
std::vector<std::string> model_fnames;
if (std::filesystem::exists(arg[3])) {
if (std::filesystem::is_directory(arg[3])) {
auto headf = std::string(arg[3]);
for (int i = 0; i < n_model; i++) {
auto stri = std::to_string(i);
model_fnames.push_back(headf + "/deployed_parallel_" + stri + ".pt");
}
} else if (std::filesystem::is_regular_file(arg[3])) {
for (int i = 3; i < n_model + 3; i++) {
model_fnames.push_back(std::string(arg[i]));
}
chem_arg_i = n_model + 3;
} else {
error->all(FLERR, "No such file or directory:" + std::string(arg[3]));
}
} catch (const c10::Error &e) {
error->all(FLERR, "error loading the model, check the path of the model");
}

for (const auto &modelf : model_fnames) {
if (!std::filesystem::is_regular_file(modelf)) {
error->all(FLERR, "Expected this is a regular file:" + modelf);
}
model_list.push_back(torch::jit::load(modelf, device, meta_dict));
}

torch::jit::setGraphExecutorOptimize(false);
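With the model-loading change above, the third pair_coeff argument is the number of deployed models and the fourth is either a directory containing deployed_parallel_<i>.pt files or the first of that many explicit model paths, followed by the chemical species. A minimal usage sketch (directory, file names, and species are illustrative):

pair_coeff * * 2 ./potential_dir Hf O
pair_coeff * * 2 deployed_parallel_0.pt deployed_parallel_1.pt Hf O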
@@ -620,18 +623,19 @@ void PairE3GNNParallel::coeff(int narg, char **arg) {
tok = std::strtok(nullptr, delim);
}

// what if unkown chemical specie is in arg? should I abort? is there any use
// what if unknown chemical specie is in arg? should I abort? is there any use
// case for that?
bool found_flag = false;
for (int i = 3 + n_model; i < narg; i++) {
int n_chem = narg - chem_arg_i;
for (int i = 0; i < n_chem; i++) {
found_flag = false;
for (int j = 0; j < chem_vec.size(); j++) {
if (chem_vec[j].compare(arg[i]) == 0) {
map[i - 2 - n_model] = j; // store from 1, (not 0)
if (chem_vec[j].compare(arg[i + chem_arg_i]) == 0) {
map[i + 1] = j; // store from 1, (not 0)
found_flag = true;
if (lmp->logfile) {
fprintf(lmp->logfile, "Chemical specie '%s' is assigned to type %d\n",
arg[i], i - 2 - n_model);
arg[i + chem_arg_i], i + 1);
break;
}
}
@@ -670,25 +674,34 @@ void PairE3GNNParallel::init_style() {

double PairE3GNNParallel::init_one(int i, int j) { return cutoff; }

void PairE3GNNParallel::notify_proc_ids(const int *sendproc, const int *recvproc) {
for (int iswap = 0; iswap < 6; iswap++) {
this->sendproc[iswap] = sendproc[iswap];
this->recvproc[iswap] = recvproc[iswap];
}
}

void PairE3GNNParallel::comm_preprocess() {
assert(!comm_preprocess_done);
CommBrick *comm_brick = dynamic_cast<CommBrick *>(comm);

// false communication to preprocess index
// result in completed comm_index_pack/unpack_forward & extra_graph_idx_map
// fake lammps communication call to preprocess index
// gives complete comm_index_pack, unpack_forward, and extra_graph_idx_map
comm_brick->forward_comm(this);

std::set<int> already_met;
std::map<int, std::set<int>> already_met_map;
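// one set of already-visited indices per destination proc (keyed by sendproc),
// so duplicate indices are tracked per target proc rather than globally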

for (int comm_phase = 0; comm_phase < 6; comm_phase++) {
const int n = comm_index_pack_forward[comm_phase].size();
if (n == 0) {
// do nothing if self comm
continue;
int sproc = this->sendproc[comm_phase];
if (already_met_map.count(sproc) == 0) {
already_met_map.insert({sproc, std::set<int>()});
}

// for unpack_reverse, Ignore duplicated index by 'already_met'
std::vector<long> &idx_map_forward = comm_index_pack_forward[comm_phase];
std::vector<long> &idx_map_reverse = comm_index_unpack_reverse[comm_phase];
std::set<int>& already_met = already_met_map[sproc];
// the last index of x_comm is used to trash unnecessary values
const int trash_index =
graph_size + static_cast<int>(extra_graph_idx_map.size()); //+ 1;
@@ -707,19 +720,11 @@
}

if (use_cuda_mpi) {
comm_index_pack_forward_tensor[comm_phase] =
torch::from_blob(idx_map_forward.data(), idx_map_forward.size(),
INTEGER_TYPE)
.to(device);
comm_index_pack_forward_tensor[comm_phase] = torch::from_blob(idx_map_forward.data(), idx_map_forward.size(), INTEGER_TYPE).to(device);

auto upmap = comm_index_unpack_forward[comm_phase];
comm_index_unpack_forward_tensor[comm_phase] =
torch::from_blob(upmap.data(), upmap.size(), INTEGER_TYPE).to(device);

comm_index_unpack_reverse_tensor[comm_phase] =
torch::from_blob(idx_map_reverse.data(), idx_map_reverse.size(),
INTEGER_TYPE)
.to(device);
comm_index_unpack_forward_tensor[comm_phase] = torch::from_blob(upmap.data(), upmap.size(), INTEGER_TYPE).to(device);
comm_index_unpack_reverse_tensor[comm_phase] = torch::from_blob(idx_map_reverse.data(), idx_map_reverse.size(), INTEGER_TYPE).to(device);
}
}
comm_preprocess_done = true;
@@ -783,24 +788,12 @@ void PairE3GNNParallel::unpack_forward_init(int n, int first, int comm_phase) {
int PairE3GNNParallel::pack_forward_comm_gnn(float *buf, int comm_phase) {
std::vector<long> &idx_map = comm_index_pack_forward[comm_phase];
const int n = static_cast<int>(idx_map.size());

if (use_cuda_mpi) {
if (use_cuda_mpi && n != 0) {
torch::Tensor &idx_map_tensor = comm_index_pack_forward_tensor[comm_phase];
auto selected =
x_comm.index_select(0, idx_map_tensor); // its size is x_dim * n
auto selected = x_comm.index_select(0, idx_map_tensor); // its size is x_dim * n
cudaError_t cuda_err =
cudaMemcpy(buf, selected.data_ptr<float>(), (x_dim * n) * sizeof(float),
cudaMemcpyDeviceToDevice);

// TODO: I want to remove temporary selected tensor for speed.
// Code below produce wrong results. But if I change {n, x_dim} to {x_dim,
// n}, get correct result. Instead, it raises a warning that the dimension
// is not correct so they implicitly changed resulting tensor shape to fit
// out_tensor(buf_tensor)'s dimension. How can I sovle this?

// auto buf_tensor = torch::from_blob(buf, {n, x_dim},
// FLOAT_TYPE.device(device)); // tensor wrapping of buf
// at::index_select_out(buf_tensor, x_comm, 0, idx_map_tensor);
} else {
int i, j, m;
m = 0;
@@ -828,11 +821,8 @@ void PairE3GNNParallel::unpack_forward_comm_gnn(float *buf, int comm_phase) {
std::vector<long> &idx_map = comm_index_unpack_forward[comm_phase];
const int n = static_cast<int>(idx_map.size());

if (use_cuda_mpi) {
torch::Tensor &idx_map_tensor =
comm_index_unpack_forward_tensor[comm_phase];
// share same memory space with exisitng device buffer just wrapping to
// troch::Tensor
if (use_cuda_mpi && n != 0) {
torch::Tensor &idx_map_tensor = comm_index_unpack_forward_tensor[comm_phase];
auto buf_tensor =
torch::from_blob(buf, {n, x_dim}, FLOAT_TYPE.device(device));
x_comm.scatter_(0, idx_map_tensor.repeat_interleave(x_dim).view({n, x_dim}),
@@ -854,13 +844,10 @@ int PairE3GNNParallel::pack_reverse_comm_gnn(float *buf, int comm_phase) {
std::vector<long> &idx_map = comm_index_unpack_forward[comm_phase];
const int n = static_cast<int>(idx_map.size());

if (use_cuda_mpi) {
torch::Tensor &idx_map_tensor =
comm_index_unpack_forward_tensor[comm_phase];
if (use_cuda_mpi && n != 0) {
torch::Tensor &idx_map_tensor = comm_index_unpack_forward_tensor[comm_phase];
auto selected = x_comm.index_select(0, idx_map_tensor);
cudaError_t cuda_err =
cudaMemcpy(buf, selected.data_ptr<float>(), (x_dim * n) * sizeof(float),
cudaMemcpyDeviceToDevice);
cudaError_t cuda_err = cudaMemcpy(buf, selected.data_ptr<float>(), (x_dim * n) * sizeof(float), cudaMemcpyDeviceToDevice);
} else {
int i, j, m;
m = 0;
@@ -879,7 +866,6 @@
std::cout << world_rank << " pack_reverse x_dim*n: " << x_dim * n
<< std::endl;
double Msend = static_cast<double>(x_dim * n * 4) / (1024 * 1024);
std::cout << world_rank << " send size(MB): " << Msend << "\n" << std::endl;
}
return x_dim * n;
}
Expand All @@ -888,9 +874,8 @@ void PairE3GNNParallel::unpack_reverse_comm_gnn(float *buf, int comm_phase) {
std::vector<long> &idx_map = comm_index_unpack_reverse[comm_phase];
const int n = static_cast<int>(idx_map.size());

if (use_cuda_mpi) {
torch::Tensor &idx_map_tensor =
comm_index_unpack_reverse_tensor[comm_phase];
if (use_cuda_mpi && n != 0) {
torch::Tensor &idx_map_tensor = comm_index_unpack_reverse_tensor[comm_phase];
auto buf_tensor =
torch::from_blob(buf, {n, x_dim}, FLOAT_TYPE.device(device));
x_comm.scatter_(0, idx_map_tensor.repeat_interleave(x_dim).view({n, x_dim}),