From 250134726a509a887a21ec301bacbae96358bd8e Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Thu, 31 Mar 2022 16:16:15 -0500 Subject: [PATCH] info: add MPIX_Info_set_hex Add MPIX_Info_set_hex to allow user to provide info hints with binary values. This routine provide a consistent way of encoding binary values into string that implementation can decode. --- src/binding/c/stream_api.txt | 6 ++ src/mpi/stream/stream_impl.c | 97 +++++++++++++++++++---------- test/mpi/impls/mpich/cuda/stream.cu | 17 +---- 3 files changed, 71 insertions(+), 49 deletions(-) diff --git a/src/binding/c/stream_api.txt b/src/binding/c/stream_api.txt index 16d94f33017..6e708f85fe6 100644 --- a/src/binding/c/stream_api.txt +++ b/src/binding/c/stream_api.txt @@ -31,3 +31,9 @@ MPIX_Recv_enqueue: MPIX_Stream_synchronize: stream: STREAM, [stream object] + +MPIX_Info_set_hex: + info: INFO, direction=in, [info object] + key: STRING, constant=True, [key] + value: BUFFER, constant=True, [value] + value_size: INFO_VALUE_LENGTH, [size of value] diff --git a/src/mpi/stream/stream_impl.c b/src/mpi/stream/stream_impl.c index 989b771ba43..d5257ee5de4 100644 --- a/src/mpi/stream/stream_impl.c +++ b/src/mpi/stream/stream_impl.c @@ -28,39 +28,7 @@ int MPIR_Stream_free_impl(MPIR_Stream * stream_ptr) * * Returns MPICH error codes */ -static int hex_val(char c) -{ - if (c >= '0' && c <= '9') { - return c - '0'; - } else if (c >= 'a' && c <= 'f') { - return c - 'a' + 10; - } else if (c >= 'A' && c <= 'F') { - return c - 'A' + 10; - } else { - return -1; - } -} - -static int hex_decode(const char *str, void *buf, int len) -{ - int n = strlen(str); - if (n != len * 2) { - return 1; - } - - unsigned char *s = buf; - for (int i = 0; i < len; i++) { - int a = hex_val(str[i * 2]); - int b = hex_val(str[i * 2 + 1]); - if (a < 0 || b < 0) { - return 1; - } - s[i] = (unsigned char) ((a << 4) + b); - } - - return 0; -} - +static int hex_decode(const char *str, void *buf, int len); int MPIR_Stream_create_impl(MPIR_Info * info_ptr, MPIR_Stream ** p_stream_ptr) { int mpi_errno = MPI_SUCCESS; @@ -341,3 +309,66 @@ int MPIR_Stream_synchronize_impl(MPIR_Stream * stream_ptr) return mpi_errno; } + +static int hex_encode(char *str, const void *value, int len); +int MPIR_Info_set_hex_impl(MPIR_Info * info_ptr, const char *key, const void *value, int value_size) +{ + int mpi_errno = MPI_SUCCESS; + + char value_buf[1024]; + MPIR_Assertp(value_size * 2 + 1 < 1024); + + hex_encode(value_buf, value, value_size); + + mpi_errno = MPIR_Info_set_impl(info_ptr, key, value_buf); + + return mpi_errno; +} + +/* ---- internal utility ---- */ + +static int hex_val(char c) +{ + if (c >= '0' && c <= '9') { + return c - '0'; + } else if (c >= 'a' && c <= 'f') { + return c - 'a' + 10; + } else if (c >= 'A' && c <= 'F') { + return c - 'A' + 10; + } else { + return -1; + } +} + +static int hex_encode(char *str, const void *value, int len) +{ + /* assume the size of str is already validated */ + + const unsigned char *s = value; + + for (int i = 0; i < len; i++) { + sprintf(str + i * 2, "%02x", s[i]); + } + + return 0; +} + +static int hex_decode(const char *str, void *buf, int len) +{ + int n = strlen(str); + if (n != len * 2) { + return 1; + } + + unsigned char *s = buf; + for (int i = 0; i < len; i++) { + int a = hex_val(str[i * 2]); + int b = hex_val(str[i * 2 + 1]); + if (a < 0 || b < 0) { + return 1; + } + s[i] = (unsigned char) ((a << 4) + b); + } + + return 0; +} diff --git a/test/mpi/impls/mpich/cuda/stream.cu b/test/mpi/impls/mpich/cuda/stream.cu index 6f73831bfa1..7943a8421de 100644 --- a/test/mpi/impls/mpich/cuda/stream.cu +++ b/test/mpi/impls/mpich/cuda/stream.cu @@ -47,17 +47,6 @@ void saxpy(int n, float a, float *x, float *y) if (i < n) y[i] = a*x[i] + y[i]; } -static void hex_encode(char *buf, int len, cudaStream_t stream) -{ - int n = sizeof(cudaStream_t); - assert(len >= n * 2 + 1); - - unsigned char *s = (unsigned char *) &stream; - for (int i = 0; i < n; i++) { - sprintf(buf + i * 2, "%02x", s[i]); - } -} - int main(void) { int errs = 0; @@ -89,14 +78,10 @@ int main(void) init_y(y); } - /* hexadecimal encoding */ - char str_stream[100]; - hex_encode(str_stream, 100, stream); - MPI_Info info; MPI_Info_create(&info); MPI_Info_set(info, "type", "cudaStream_t"); - MPI_Info_set(info, "id", str_stream); + MPIX_Info_set_hex(info, "id", &stream, sizeof(stream)); MPIX_Stream mpi_stream; MPIX_Stream_create(info, &mpi_stream);