-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCudaDomdecGroups.h
151 lines (118 loc) · 3.98 KB
/
CudaDomdecGroups.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#ifndef CUDADOMDECGROUPS_H
#define CUDADOMDECGROUPS_H
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <map>
#include <vector>
#include "DomdecGroups.h"
#include "Bonded_struct.h"
#include "CudaDomdec.h"
#include "CudaAtomGroup.h"
//
// Registry of atom groups for the CUDA domain decomposition.
// Groups are registered on the host with insertGroup(); per-group tables are
// then produced by buildGroupTables()/syncGroupTables().
// NOTE(review): the intended call sequence (beginGroups -> insertGroup* ->
// finishGroups -> buildGroupTables) is inferred from member names and from the
// regGroups size assert in insertGroup — confirm against the implementation file.
//
class CudaDomdecGroups : public DomdecGroups {
private:
  const CudaDomdec& domdec;

  // Atom group pointers in a map:
  // <id, atomgroup*>
  std::map<int, AtomGroupBase*> atomGroups;

  // Atom group pointers in a vector:
  std::vector<AtomGroupBase*> atomGroupVector;

  int** groupTable;
  std::vector<int*> h_groupTable;
  int* groupDataStart;
  int* groupData;
  int* groupTablePos;
  int* h_groupTablePos;
  bool tbl_upto_date;

  // Storage vector used for registering groups.
  // regGroups[t] collects every group whose smallest atom index is t; each
  // group is appended as the record
  //   [ (size << 16) | type, groupIndex, <atoms of the group...> ]
  // (see insertGroup).
  std::vector< std::vector<int> > regGroups;

  // True for Group structures that also contain constraint groups
  bool hasConstGroups;
  // Start of constraint groups
  int typeConstStart;
  int* constTable;
  int* constTablePos;
  int* h_constTablePos;

  //int** nodeTable;
  //int* nodeTablePos;
  // NOTE: this contains the device pointers. Only kept for deallocation
  //int** h_nodeTable;

  int coordTmpLen;
  int* coordTmp;

  int coordIndLen;
  int* coordInd;

  int* neighPos;
  int* h_neighPos;

  // Non-copyable: the object holds raw buffer pointers that are presumably
  // released by the destructor, so an implicit member-wise copy would alias
  // (and double-free) them. Declared but intentionally never defined
  // (C++98-compatible idiom).
  CudaDomdecGroups(const CudaDomdecGroups&);
  CudaDomdecGroups& operator=(const CudaDomdecGroups&);

public:
  CudaDomdecGroups(const CudaDomdec& domdec);
  ~CudaDomdecGroups();

  std::vector<AtomGroupBase*>& get_atomGroupVector() {return atomGroupVector;}

  void beginGroups();

  //
  // Register groups.
  // h_groupList[] is the host version of atomGroup.groupList[]
  //
  // id          : group ID; must be unique, a duplicate is a fatal error.
  // atomGroup   : group container; its type is set to its registration order
  //               (number of groups registered before it).
  // h_groupList : host array with atomGroup.get_numGroupList() entries.
  //
  // Each group is filed under its smallest atom index in regGroups. The
  // header word packs the group size into the upper bits and the type into
  // the low 16 bits, so type must be < 2^16 and size < 2^15 (to stay within a
  // signed int); violations, and groups with no atoms, terminate the program.
  //
  template <typename T>
  void insertGroup(int id, CudaAtomGroup<T>& atomGroup, T* h_groupList) {
    assert(regGroups.size() == domdec.get_ncoord_glo());
    int type = atomGroups.size();
    int size = T::size();

    std::pair<std::map<int, AtomGroupBase*>::iterator, bool> ret =
      atomGroups.insert(std::pair<int, AtomGroupBase*>(id, &atomGroup));
    if (ret.second == false) {
      std::cerr << "CudaDomdecGroups::insertGroup, group IDs must be unique" << std::endl;
      exit(1);
    }

    // The packed header word (size << 16) | type only round-trips if both
    // fields fit their bit ranges; otherwise they silently collide.
    if (type < 0 || type > 0xffff || size < 0 || size > 0x7fff) {
      std::cerr << "CudaDomdecGroups::insertGroup, group type or size out of range" << std::endl;
      exit(1);
    }

    // Set group type
    atomGroup.set_type(type);

    const int header = (size << 16) | type;
    const int numGroupList = atomGroup.get_numGroupList();
    // Reused across iterations; cleared so getAtoms() always sees an empty vector
    std::vector<int> atoms;
    for (int i=0;i < numGroupList;i++) {
      // Get atoms that are in group
      atoms.clear();
      h_groupList[i].getAtoms(atoms);
      // min_element on an empty range returns end(); dereferencing it is UB
      if (atoms.empty()) {
        std::cerr << "CudaDomdecGroups::insertGroup, group has no atoms" << std::endl;
        exit(1);
      }
      int t = *std::min_element(atoms.begin(), atoms.end());
      std::vector<int>& reg = regGroups.at(t);
      reg.push_back(header);
      reg.push_back(i);
      reg.insert(reg.end(), atoms.begin(), atoms.end());
    }
  }

  void finishGroups();

  void buildGroupTables(cudaStream_t stream=0);
  void syncGroupTables(cudaStream_t stream=0);

  // Return group list of group "id", or NULL if "id" is not registered.
  // Terminates the program if "id" is registered with a different element type T.
  // NOTE: This is constant during the run
  template <typename T>
  T* getGroupList(const int id) {
    std::map<int, AtomGroupBase*>::iterator it = atomGroups.find(id);
    if (it == atomGroups.end()) return NULL;
    CudaAtomGroup<T>* p = dynamic_cast< CudaAtomGroup<T>* >( it->second );
    if (p == NULL) {
      std::cerr << "CudaDomdecGroups::getGroupList, dynamic_cast failed" << std::endl;
      exit(1);
    }
    return p->get_groupList();
  }

  // Return number of groups in list (0 if "id" is not registered)
  // NOTE: This is constant during the run
  int getNumGroupList(const int id) {
    std::map<int, AtomGroupBase*>::iterator it = atomGroups.find(id);
    // Group "id" not found => return zero
    if (it == atomGroups.end()) return 0;
    return it->second->get_numGroupList();
  }

  // Return group table (NULL if "id" is not registered)
  // NOTE: This changes at neighborlist update
  int* getGroupTable(const int id) {
    std::map<int, AtomGroupBase*>::iterator it = atomGroups.find(id);
    if (it == atomGroups.end()) return NULL;
    return it->second->get_table();
  }

  // Return number of entries in group table (0 if "id" is not registered)
  // NOTE: This changes at neighborlist update
  int getNumGroupTable(const int id) {
    std::map<int, AtomGroupBase*>::iterator it = atomGroups.find(id);
    if (it == atomGroups.end()) return 0;
    return it->second->get_numTable();
  }

  // Returns h_neighPos (presumably the host-side copy of neighPos — confirm)
  const int* getNeighPos() {return h_neighPos;}
  int* getCoordInd() {return coordInd;}
};
#endif // CUDADOMDECGROUPS_H