diff --git a/backends/cuda-ref/ceed-cuda-ref-basis.c b/backends/cuda-ref/ceed-cuda-ref-basis.c index c245a51489..1c38ce002c 100644 --- a/backends/cuda-ref/ceed-cuda-ref-basis.c +++ b/backends/cuda-ref/ceed-cuda-ref-basis.c @@ -40,9 +40,14 @@ static int CeedBasisApplyCore_Cuda(CeedBasis basis, bool apply_add, const CeedIn // Clear v for transpose operation if (is_transpose && !apply_add) { + CeedInt num_comp, q_comp, num_nodes, num_qpts; CeedSize length; - CeedCallBackend(CeedVectorGetLength(v, &length)); + CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); + CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); + CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); + CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts)); + length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp)); CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar))); } CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d)); @@ -206,9 +211,14 @@ static int CeedBasisApplyAtPointsCore_Cuda(CeedBasis basis, bool apply_add, cons // Clear v for transpose operation if (is_transpose && !apply_add) { + CeedInt num_comp, q_comp, num_nodes; CeedSize length; - CeedCallBackend(CeedVectorGetLength(v, &length)); + CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); + CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); + CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); + length = + (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp)); CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar))); } @@ -283,9 +293,12 @@ static int CeedBasisApplyNonTensorCore_Cuda(CeedBasis basis, bool apply_add, con // Clear v for transpose operation if (is_transpose && !apply_add) { + CeedInt num_comp, q_comp; CeedSize length; - CeedCallBackend(CeedVectorGetLength(v, &length)); + CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); + CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); + length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp)); CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar))); } diff --git a/backends/cuda-shared/ceed-cuda-shared-basis.c b/backends/cuda-shared/ceed-cuda-shared-basis.c index 8924ff52f4..c947846f3a 100644 --- a/backends/cuda-shared/ceed-cuda-shared-basis.c +++ b/backends/cuda-shared/ceed-cuda-shared-basis.c @@ -312,9 +312,14 @@ static int CeedBasisApplyAtPointsCore_Cuda_shared(CeedBasis basis, bool apply_ad // Clear v for transpose operation if (is_transpose && !apply_add) { + CeedInt num_comp, q_comp, num_nodes; CeedSize length; - CeedCallBackend(CeedVectorGetLength(v, &length)); + CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); + CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); + CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); + length = + (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp)); CeedCallCuda(ceed, cudaMemset(d_v, 0, length * sizeof(CeedScalar))); } diff --git a/backends/hip-ref/ceed-hip-ref-basis.c b/backends/hip-ref/ceed-hip-ref-basis.c index 70dda0a7da..f54184f28d 100644 --- a/backends/hip-ref/ceed-hip-ref-basis.c +++ b/backends/hip-ref/ceed-hip-ref-basis.c @@ -39,9 +39,14 @@ static int CeedBasisApplyCore_Hip(CeedBasis basis, bool apply_add, const CeedInt // Clear v for transpose operation if (is_transpose && !apply_add) { + CeedInt num_comp, q_comp, num_nodes, num_qpts; CeedSize length; - CeedCallBackend(CeedVectorGetLength(v, &length)); + CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); + CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); + CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); + CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis, &num_qpts)); + length = (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)num_qpts * (CeedSize)q_comp)); CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar))); } CeedCallBackend(CeedBasisGetNumQuadraturePoints1D(basis, &Q_1d)); @@ -204,9 +209,14 @@ static int CeedBasisApplyAtPointsCore_Hip(CeedBasis basis, bool apply_add, const // Clear v for transpose operation if (is_transpose && !apply_add) { + CeedInt num_comp, q_comp, num_nodes; CeedSize length; - CeedCallBackend(CeedVectorGetLength(v, &length)); + CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); + CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); + CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); + length = + (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp)); CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar))); } diff --git a/backends/hip-shared/ceed-hip-shared-basis.c b/backends/hip-shared/ceed-hip-shared-basis.c index 307107ec6b..05b564e7f2 100644 --- a/backends/hip-shared/ceed-hip-shared-basis.c +++ b/backends/hip-shared/ceed-hip-shared-basis.c @@ -371,9 +371,14 @@ static int CeedBasisApplyAtPointsCore_Hip_shared(CeedBasis basis, bool apply_add // Clear v for transpose operation if (is_transpose && !apply_add) { + CeedInt num_comp, q_comp, num_nodes; CeedSize length; - CeedCallBackend(CeedVectorGetLength(v, &length)); + CeedCallBackend(CeedBasisGetNumComponents(basis, &num_comp)); + CeedCallBackend(CeedBasisGetNumQuadratureComponents(basis, eval_mode, &q_comp)); + CeedCallBackend(CeedBasisGetNumNodes(basis, &num_nodes)); + length = + (CeedSize)num_elem * (CeedSize)num_comp * (t_mode == CEED_TRANSPOSE ? (CeedSize)num_nodes : ((CeedSize)max_num_points * (CeedSize)q_comp)); CeedCallHip(ceed, hipMemset(d_v, 0, length * sizeof(CeedScalar))); } diff --git a/interface/ceed-basis.c b/interface/ceed-basis.c index aa19489e0a..c6869f2f3b 100644 --- a/interface/ceed-basis.c +++ b/interface/ceed-basis.c @@ -333,10 +333,10 @@ static int CeedBasisApplyAtPointsCheckDims(CeedBasis basis, CeedInt num_elem, co // Check compatibility coordinates vector for (CeedInt i = 0; i < num_elem; i++) total_num_points += num_points[i]; - CeedCheck((x_length >= total_num_points * dim) || (eval_mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_DIMENSION, + CeedCheck((x_length >= (CeedSize)total_num_points * (CeedSize)dim) || (eval_mode == CEED_EVAL_WEIGHT), ceed, CEED_ERROR_DIMENSION, "Length of reference coordinate vector incompatible with basis dimension and number of points." " Found reference coordinate vector of length %" CeedSize_FMT ", not of length %" CeedSize_FMT ".", - x_length, total_num_points * dim); + x_length, (CeedSize)total_num_points * (CeedSize)dim); // Check CEED_EVAL_WEIGHT only on CEED_NOTRANSPOSE CeedCheck(eval_mode != CEED_EVAL_WEIGHT || t_mode == CEED_NOTRANSPOSE, ceed, CEED_ERROR_UNSUPPORTED, @@ -346,13 +346,16 @@ static int CeedBasisApplyAtPointsCheckDims(CeedBasis basis, CeedInt num_elem, co bool has_good_dims = true; switch (eval_mode) { case CEED_EVAL_INTERP: - has_good_dims = ((t_mode == CEED_TRANSPOSE && (u_length >= total_num_points * num_q_comp || v_length >= num_elem * num_nodes * num_comp)) || - (t_mode == CEED_NOTRANSPOSE && (v_length >= total_num_points * num_q_comp || u_length >= num_elem * num_nodes * num_comp))); + has_good_dims = ((t_mode == CEED_TRANSPOSE && (u_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp || + v_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)) || + (t_mode == CEED_NOTRANSPOSE && (v_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp || + u_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp))); break; case CEED_EVAL_GRAD: - has_good_dims = - ((t_mode == CEED_TRANSPOSE && (u_length >= total_num_points * num_q_comp * dim || v_length >= num_elem * num_nodes * num_comp)) || - (t_mode == CEED_NOTRANSPOSE && (v_length >= total_num_points * num_q_comp * dim || u_length >= num_elem * num_nodes * num_comp))); + has_good_dims = ((t_mode == CEED_TRANSPOSE && (u_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp * (CeedSize)dim || + v_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp)) || + (t_mode == CEED_NOTRANSPOSE && (v_length >= (CeedSize)total_num_points * (CeedSize)num_q_comp * (CeedSize)dim || + u_length >= (CeedSize)num_elem * (CeedSize)num_nodes * (CeedSize)num_comp))); break; case CEED_EVAL_WEIGHT: has_good_dims = t_mode == CEED_NOTRANSPOSE && (v_length >= total_num_points); @@ -1822,12 +1825,13 @@ static int CeedBasisApplyCheckDims(CeedBasis basis, CeedInt num_elem, CeedTransp case CEED_EVAL_GRAD: case CEED_EVAL_DIV: case CEED_EVAL_CURL: - has_good_dims = - ((t_mode == CEED_TRANSPOSE && u_length >= num_elem * num_comp * num_qpts * q_comp && v_length >= num_elem * num_comp * num_nodes) || - (t_mode == CEED_NOTRANSPOSE && v_length >= num_elem * num_qpts * num_comp * q_comp && u_length >= num_elem * num_comp * num_nodes)); + has_good_dims = ((t_mode == CEED_TRANSPOSE && u_length >= (CeedSize)num_elem * (CeedSize)num_comp * (CeedSize)num_qpts * (CeedSize)q_comp && + v_length >= (CeedSize)num_elem * (CeedSize)num_comp * (CeedSize)num_nodes) || + (t_mode == CEED_NOTRANSPOSE && v_length >= (CeedSize)num_elem * (CeedSize)num_qpts * (CeedSize)num_comp * (CeedSize)q_comp && + u_length >= (CeedSize)num_elem * (CeedSize)num_comp * (CeedSize)num_nodes)); break; case CEED_EVAL_WEIGHT: - has_good_dims = v_length >= num_elem * num_qpts; + has_good_dims = v_length >= (CeedSize)num_elem * (CeedSize)num_qpts; break; } CeedCheck(has_good_dims, ceed, CEED_ERROR_DIMENSION, "Input/output vectors too short for basis and evaluation mode");