Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify for getting solutions when setting alpha to 0 with A/B is nullptr #956

Merged
merged 6 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clients/gtest/auxiliary_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ namespace
testing_aux_matmul_alg_init(arg);
else if(!strcmp(arg.function, "aux_get_sol_with_null_biasaddr"))
testing_aux_get_sol_with_null_biasaddr(arg);
else if(!strcmp(arg.function, "aux_get_sol_with_zero_alpha_null_a_b"))
testing_aux_get_sol_with_zero_alpha_null_a_b(arg);
else if(!strcmp(arg.function, "aux_get_sol_with_zero_alpha_null_a_b_ext"))
testing_aux_get_sol_with_zero_alpha_null_a_b_ext(arg);
else if(!strcmp(arg.function, "aux_matmul_alg_get_attr_bad_arg"))
testing_aux_matmul_alg_get_attr_bad_arg(arg);
else if(!strcmp(arg.function, "aux_matmul_plan_init_bad_arg"))
Expand Down Expand Up @@ -153,6 +157,8 @@ namespace
|| !strcmp(arg.function, "aux_matmul_alg_init_bad_arg")
|| !strcmp(arg.function, "aux_matmul_alg_init")
|| !strcmp(arg.function, "aux_get_sol_with_null_biasaddr")
|| !strcmp(arg.function, "aux_get_sol_with_zero_alpha_null_a_b")
|| !strcmp(arg.function, "aux_get_sol_with_zero_alpha_null_a_b_ext")
|| !strcmp(arg.function, "aux_matmul_alg_get_attr_bad_arg")
|| !strcmp(arg.function, "aux_matmul_plan_init_bad_arg")
|| !strcmp(arg.function, "aux_matmul_plan_init")
Expand Down
14 changes: 14 additions & 0 deletions clients/gtest/auxiliary_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ Tests:
function:
- aux_get_sol_with_null_biasaddr: *hpa_half_precision

- name: aux_get_sol_with_zero_alpha_null_a_b
category: pre_checkin
function:
- aux_get_sol_with_zero_alpha_null_a_b: *hpa_half_precision
alpha: 0
beta: 1

- name: aux_get_sol_with_zero_alpha_null_a_b_ext
category: pre_checkin
function:
- aux_get_sol_with_zero_alpha_null_a_b_ext: *hpa_half_precision
alpha: 0
beta: 1

- name: aux_matmul_alg_get_attr_bad_arg
category: pre_checkin
function:
Expand Down
156 changes: 156 additions & 0 deletions clients/include/testing_auxiliary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "unit.hpp"
#include "utility.hpp"
#include <hipblaslt/hipblaslt.h>
#include <hipblaslt/hipblaslt-ext.hpp> // Add check for hipblaslt-ext

void testing_aux_handle_init_bad_arg(const Arguments& arg)
{
Expand Down Expand Up @@ -392,6 +393,161 @@ void testing_aux_get_sol_with_null_biasaddr(const Arguments& arg)
CHECK_HIP_ERROR(hipStreamDestroy(stream));
}

// hipBLASLt API: For testing case of (alpha=0 && (A=NULL || B=NULL))
void testing_aux_get_sol_with_zero_alpha_null_a_b(const Arguments& arg)
{
using InTypeA = hipblasLtHalf;
using InTypeB = hipblasLtHalf;
using OutType = hipblasLtHalf;
using AlphaType = hipblasLtFloat;
using BetaType = hipblasLtFloat;

hipStream_t stream;
hipblasLtHandle_t handle;
hipblasOperation_t trans_a = arg.transA == 'N' ? HIPBLAS_OP_N : HIPBLAS_OP_T;
hipblasOperation_t trans_b = arg.transB == 'N' ? HIPBLAS_OP_N : HIPBLAS_OP_T;
int64_t m = arg.M[0];
int64_t n = arg.N[0];
int64_t k = arg.K[0];
int64_t batch_count = 1;
// Setting alpha = 0.
float alpha = 0;
float beta = arg.beta;
// Setting d_a, d_b as nullptr.
void* d_a = NULL;
void* d_b = NULL;
void* d_c;
void* d_d;

CHECK_HIP_ERROR(hipStreamCreate(&stream));
CHECK_HIPBLASLT_ERROR(hipblasLtCreate(&handle));
CHECK_HIP_ERROR(hipMalloc(&d_c, m * n * batch_count * sizeof(OutType)));
CHECK_HIP_ERROR(hipMalloc(&d_d, m * n * batch_count * sizeof(OutType)));

hipblasLtMatrixLayout_t matA, matB, matC, matD;
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matA, arg.a_type, m, k, m));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matB, arg.a_type, k, n, k));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matC, arg.a_type, m, n, m));
CHECK_HIPBLASLT_ERROR(hipblasLtMatrixLayoutCreate(&matD, arg.a_type, m, n, m));

hipblasLtMatmulDesc_t matmul;
CHECK_HIPBLASLT_ERROR(
hipblasLtMatmulDescCreate(&matmul, arg.compute_type, arg.scale_type));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(int32_t)));
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(int32_t)));

hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue,sizeof(epilogue)));

// Set User Preference attributes
hipblasLtMatmulPreference_t pref;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulPreferenceCreate(&pref));

const int request_solutions = 1;
hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
int returnedAlgoCount = 0;
CHECK_HIPBLASLT_ERROR(hipblasLtMatmulAlgoGetHeuristic(handle,
matmul,
matA,
matB,
matC,
matD,
pref,
request_solutions,
heuristicResult,
&returnedAlgoCount));

CHECK_SOLUTION_FOUND(returnedAlgoCount);

// Validation for solution running.
CHECK_HIPBLASLT_ERROR(hipblasLtMatmul(handle,
matmul,
&alpha,
d_a,
matA,
d_b,
matB,
&beta,
d_c,
matC,
d_d,
matD,
&heuristicResult[0].algo,
nullptr,
0,
stream));

CHECK_HIP_ERROR(hipFree(d_c));
CHECK_HIP_ERROR(hipFree(d_d));
CHECK_HIPBLASLT_ERROR(hipblasLtDestroy(handle));
CHECK_HIP_ERROR(hipStreamDestroy(stream));
}

// hipBLASLtExt API: For testing case of (alpha=0 && (A=NULL || B=NULL))
void testing_aux_get_sol_with_zero_alpha_null_a_b_ext(const Arguments& arg)
{
using InTypeA = hipblasLtHalf;
using InTypeB = hipblasLtHalf;
using OutType = hipblasLtHalf;
using AlphaType = hipblasLtFloat;
using BetaType = hipblasLtFloat;

hipStream_t stream;
hipblasLtHandle_t handle;
hipblasOperation_t trans_a = arg.transA == 'N' ? HIPBLAS_OP_N : HIPBLAS_OP_T;
hipblasOperation_t trans_b = arg.transB == 'N' ? HIPBLAS_OP_N : HIPBLAS_OP_T;
int64_t m = arg.M[0];
int64_t n = arg.N[0];
int64_t k = arg.K[0];
int64_t batch_count = 1;
// Setting alpha = 0.
float alpha = 0;
float beta = arg.beta;
// Setting d_a, d_b as nullptr.
void* d_a = NULL;
void* d_b = NULL;
void* d_c;
void* d_d;

CHECK_HIP_ERROR(hipStreamCreate(&stream));
CHECK_HIPBLASLT_ERROR(hipblasLtCreate(&handle));
CHECK_HIP_ERROR(hipMalloc(&d_c, m * n * batch_count * sizeof(OutType)));
CHECK_HIP_ERROR(hipMalloc(&d_d, m * n * batch_count * sizeof(OutType)));

hipblaslt_ext::GemmPreference gemmPref;
hipblaslt_ext::Gemm gemm(
handle, trans_a, trans_b, arg.a_type, arg.a_type, arg.a_type, arg.a_type, arg.compute_type);

hipblaslt_ext::GemmEpilogue
epilogue; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
hipblaslt_ext::GemmInputs inputs;
inputs.a = d_a;
inputs.b = d_b;
inputs.c = d_c;
inputs.d = d_d;
inputs.alpha = &alpha;
inputs.beta = &beta;
gemm.setProblem(m, n, k, batch_count, epilogue, inputs);

const int request_solutions = 1;
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(gemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
CHECK_SOLUTION_FOUND(heuristicResult.size());

// Make sure to initialize every time when algo changes
CHECK_HIPBLASLT_ERROR(gemm.initialize(heuristicResult[0].algo, nullptr));
// Validation for solution running.
CHECK_HIPBLASLT_ERROR(gemm.run(stream));

CHECK_HIP_ERROR(hipFree(d_c));
CHECK_HIP_ERROR(hipFree(d_d));
CHECK_HIPBLASLT_ERROR(hipblasLtDestroy(handle));
CHECK_HIP_ERROR(hipStreamDestroy(stream));
}

void testing_aux_matmul_alg_get_attr_bad_arg(const Arguments& arg) {}

void testing_aux_matmul_alg_null_matmul(const Arguments& arg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ inline rocblaslt_status validateMatmulArgs(int64_t m,
return rocblaslt_status_invalid_pointer;

// pointers must be valid
if(n && ((k && (!a || !b || !alpha)) || !c || !d))
// Update for the valid case: (alpha=0 && (A=NULL || B=NULL))
if(n && ((k && (!alpha || ((*((float*)alpha)) && (!a || !b)))) || !c || !d))
return rocblaslt_status_invalid_pointer;

return rocblaslt_status_continue;
Expand Down
5 changes: 3 additions & 2 deletions library/src/amd_detail/rocblaslt/src/rocblaslt_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,8 +1063,9 @@ rocblaslt_status rocblaslt_matmul(rocblaslt_handle handle,
}

// Check if pointer is valid
if(alpha == nullptr || beta == nullptr || A == nullptr || B == nullptr || C == nullptr
|| D == nullptr)
// Update for the valid case: (alpha=0 && (A=NULL || B=NULL))
if(alpha == nullptr || beta == nullptr || C == nullptr || D == nullptr
|| ((*((float*)alpha)) && (A == nullptr || B == nullptr)))
{
log_error(__func__, "invalid data pointer");
return rocblaslt_status_invalid_pointer;
Expand Down
5 changes: 3 additions & 2 deletions library/src/amd_detail/rocblaslt/src/tensile_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1482,8 +1482,9 @@ rocblaslt_status gemmCreate(RocblasltContractionProblem const& problem,
try
{
// Check if pointer is valid
if(problem.alpha == nullptr || problem.beta == nullptr || problem.A == nullptr
|| problem.B == nullptr || problem.C == nullptr || problem.D == nullptr)
// Update for the valid case: (alpha=0 && (A=NULL || B=NULL))
if(problem.alpha == nullptr || problem.beta == nullptr || problem.C == nullptr || problem.D == nullptr
|| ((*((float*)problem.alpha)) && (problem.A == nullptr || problem.B == nullptr)))
{
log_error(__func__, "invalid data pointer");
return rocblaslt_status_invalid_pointer;
Expand Down