Skip to content

Commit

Permalink
gpu - use cached work vectors across operators
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Sep 26, 2024
1 parent 85792eb commit 9bb9298
Show file tree
Hide file tree
Showing 8 changed files with 455 additions and 343 deletions.
382 changes: 222 additions & 160 deletions backends/cuda-ref/ceed-cuda-ref-operator.c

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion backends/cuda-ref/ceed-cuda-ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ typedef struct {
} CeedOperatorAssemble_Cuda;

typedef struct {
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out, has_shared_e_vecs;
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
uint64_t *input_states; // State tracking for passive inputs
CeedVector *e_vecs_in, *e_vecs_out;
CeedVector *q_vecs_in, *q_vecs_out;
Expand Down
392 changes: 226 additions & 166 deletions backends/hip-ref/ceed-hip-ref-operator.c

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion backends/hip-ref/ceed-hip-ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ typedef struct {
} CeedOperatorAssemble_Hip;

typedef struct {
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out, has_shared_e_vecs;
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
uint64_t *input_states; // State tracking for passive inputs
CeedVector *e_vecs_in, *e_vecs_out;
CeedVector *q_vecs_in, *q_vecs_out;
Expand Down
10 changes: 0 additions & 10 deletions interface/ceed-basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,6 @@ static int CeedBasisApplyAtPointsCheckDims(CeedBasis basis, CeedInt num_elem, co
if (x_ref != CEED_VECTOR_NONE) CeedCall(CeedVectorGetLength(x_ref, &x_length));
if (u != CEED_VECTOR_NONE) CeedCall(CeedVectorGetLength(u, &u_length));

// Check compatibility of topological and geometrical dimensions
CeedCheck((t_mode == CEED_TRANSPOSE && v_length % num_nodes == 0) || (t_mode == CEED_NOTRANSPOSE && u_length % num_nodes == 0) ||
(eval_mode == CEED_EVAL_WEIGHT),
ceed, CEED_ERROR_DIMENSION, "Length of input/output vectors incompatible with basis dimensions and number of points");

// Check compatibility coordinates vector
for (CeedInt i = 0; i < num_elem; i++) total_num_points += num_points[i];
CeedCheck((x_length >= total_num_points * dim) || (eval_mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_DIMENSION,
Expand Down Expand Up @@ -1819,11 +1814,6 @@ static int CeedBasisApplyCheckDims(CeedBasis basis, CeedInt num_elem, CeedTransp
CeedCall(CeedVectorGetLength(v, &v_length));
if (u) CeedCall(CeedVectorGetLength(u, &u_length));

// Check compatibility of topological and geometrical dimensions
CeedCheck((t_mode == CEED_TRANSPOSE && v_length % num_nodes == 0 && u_length % num_qpts == 0) ||
(t_mode == CEED_NOTRANSPOSE && u_length % num_nodes == 0 && v_length % num_qpts == 0),
ceed, CEED_ERROR_DIMENSION, "Length of input/output vectors incompatible with basis dimensions");

// Check vector lengths to prevent out of bounds issues
bool has_good_dims = true;
switch (eval_mode) {
Expand Down
2 changes: 1 addition & 1 deletion interface/ceed-vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ int CeedVectorPointwiseMult(CeedVector w, CeedVector x, CeedVector y) {
CeedCall(CeedVectorGetLength(w, &length_w));
CeedCall(CeedVectorGetLength(x, &length_x));
CeedCall(CeedVectorGetLength(y, &length_y));
CeedCheck(length_w == length_x && length_w == length_y, ceed, CEED_ERROR_UNSUPPORTED,
CeedCheck(length_x >= length_x && length_y >= length_w, ceed, CEED_ERROR_UNSUPPORTED,
"Cannot multiply vectors of different lengths."
" x length: %" CeedSize_FMT " y length: %" CeedSize_FMT,
length_x, length_y);
Expand Down
2 changes: 1 addition & 1 deletion tests/junit.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def check_required_failure(self, test: str, spec: TestSpec, resource: str, stder
elif test_id in ['t215']:
fail_str = 'Cannot destroy CeedElemRestriction, a process has read access to the offset data'
elif test_id in ['t303']:
fail_str = 'Length of input/output vectors incompatible with basis dimensions'
fail_str = 'Input/output vectors too short for basis and evaluation mode'
elif test_id in ['t408']:
fail_str = 'CeedQFunctionContextGetData(): Cannot grant CeedQFunctionContext data access, a process has read access'
elif test_id in ['t409'] and contains_any(resource, ['memcheck']):
Expand Down
6 changes: 3 additions & 3 deletions tests/t303-basis.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/// @file
/// Test checking BasisApply input/output vectors compatibility with basis dimensions
/// \test Test checking BasisApply input/output vectors compatibility with basis dimensions
/// Test checking BasisApply input/output vectors compatibility with basis
/// \test Test checking BasisApply input/output vectors compatibility with basis

//TESTARGS(only="cpu") {ceed_resource}
#include <ceed.h>
Expand All @@ -15,7 +15,7 @@ int main(int argc, char **argv) {
CeedInit(argv[1], &ceed);

CeedVectorCreate(ceed, len, &u);
CeedVectorCreate(ceed, len + 1, &v);
CeedVectorCreate(ceed, len - 1, &v);

CeedBasisCreateTensorH1Lagrange(ceed, dim, num_comp, p, q, CEED_GAUSS, &basis);

Expand Down

0 comments on commit 9bb9298

Please sign in to comment.