Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First pass at scan, implemented linear and ring algorithms #1154

Merged
merged 13 commits into from
Jan 21, 2025
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ jobs:
libfabric_version: v1.13.x
- config_name: ring reduce algorithm
env_setup: export SHMEM_REDUCE_ALGORITHM=ring
export SHMEM_SCAN_ALGORITHM=ring
sos_config: --enable-error-checking --enable-pmi-simple
libfabric_version: v1.13.x
- config_name: ring fcollect algorithm, tx/rx single poll limit
Expand Down
27 changes: 27 additions & 0 deletions mpp/shmemx.h4
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ static inline void shmemx_ibget(shmem_ctx_t ctx, $2 *target, const $2 *source,
}')dnl
SHMEM_CXX_DEFINE_FOR_RMA(`SHMEM_CXX_IBGET')

/* C++ overload template for the exclusive sum scan.  Each expansion emits a
 * static inline shmemx_sum_exscan(team, dest, source, nelems) that forwards
 * its arguments unchanged to the type-specific shmemx_<name>_sum_exscan
 * routine, where <name> is the first template argument and the second
 * template argument is the element type of dest and source.
 * NOTE(review): the binder invoked after the template is defined elsewhere;
 * presumably it expands the template once per supported type -- confirm. */
define(`SHMEM_CXX_SUM_EXSCAN',
`static inline int shmemx_sum_exscan(shmem_team_t team, $2* dest, const $2* source,
size_t nelems) {
return shmemx_$1_sum_exscan(team, dest, source, nelems);
}')dnl
SHMEM_CXX_DEFINE_FOR_COLL_SUM_PROD(`SHMEM_CXX_SUM_EXSCAN')

/* C++ overload template for the inclusive sum scan.  Each expansion emits a
 * static inline shmemx_sum_inscan(team, dest, source, nelems) that forwards
 * its arguments unchanged to the type-specific shmemx_<name>_sum_inscan
 * routine, where <name> is the first template argument and the second
 * template argument is the element type of dest and source.
 * NOTE(review): the binder invoked after the template is defined elsewhere;
 * presumably it expands the template once per supported type -- confirm. */
define(`SHMEM_CXX_SUM_INSCAN',
`static inline int shmemx_sum_inscan(shmem_team_t team, $2* dest, const $2* source,
size_t nelems) {
return shmemx_$1_sum_inscan(team, dest, source, nelems);
}')dnl
SHMEM_CXX_DEFINE_FOR_COLL_SUM_PROD(`SHMEM_CXX_SUM_INSCAN')

/* C11 Generic Macros */
#elif (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(SHMEM_INTERNAL_INCLUDE))

Expand Down Expand Up @@ -105,6 +119,19 @@ SHMEM_BIND_C11_RMA(`SHMEM_C11_GEN_IBGET', `, \') \
uint64_t*: shmemx_signal_add \
)(__VA_ARGS__)

/* C11 type-generic front end for the exclusive sum scan.  The template
 * below contributes one _Generic association per expansion, mapping a
 * pointer to the element type onto the matching type-specific
 * shmemx_<name>_sum_exscan routine; the selected routine is then called
 * with the original argument list.
 * NOTE(review): which call argument drives the selection is decided by the
 * helper macros on the _Generic line, defined elsewhere -- confirm. */
define(`SHMEM_C11_GEN_EXSCAN', ` $2*: shmemx_$1_sum_exscan')dnl
#define shmemx_sum_exscan(...) \
_Generic(SHMEM_C11_TYPE_EVAL_PTR(SHMEM_C11_ARG1(__VA_ARGS__)), \
SHMEM_BIND_C11_COLL_SUM_PROD(`SHMEM_C11_GEN_EXSCAN', `, \') \
)(__VA_ARGS__)

/* C11 type-generic front end for the inclusive sum scan.  The template
 * below contributes one _Generic association per expansion, mapping a
 * pointer to the element type onto the matching type-specific
 * shmemx_<name>_sum_inscan routine; the selected routine is then called
 * with the original argument list.
 * NOTE(review): which call argument drives the selection is decided by the
 * helper macros on the _Generic line, defined elsewhere -- confirm. */
define(`SHMEM_C11_GEN_INSCAN', ` $2*: shmemx_$1_sum_inscan')dnl
#define shmemx_sum_inscan(...) \
_Generic(SHMEM_C11_TYPE_EVAL_PTR(SHMEM_C11_ARG1(__VA_ARGS__)), \
SHMEM_BIND_C11_COLL_SUM_PROD(`SHMEM_C11_GEN_INSCAN', `, \') \
)(__VA_ARGS__)


#endif /* C11 */

#endif /* SHMEMX_H */
10 changes: 10 additions & 0 deletions mpp/shmemx_c_func.h4
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ SH_PAD(`$1') ptrdiff_t tst, ptrdiff_t sst,
SH_PAD(`$1') size_t bsize, size_t nblocks, int pe)')dnl
SHMEM_DECLARE_FOR_SIZES(`SHMEM_C_CTX_IBGET_N')

/* Prototype template for the typed exclusive scan routines.  Each expansion
 * declares an int-returning shmemx_<name>_<op>_exscan(team, dest, source,
 * nelems), built from the first (type name), second (element type) and
 * fourth (operation name) template arguments.  The binder below is invoked
 * with the operation sum.
 * NOTE(review): the binder is defined elsewhere; presumably it expands the
 * template once per supported type -- confirm. */
define(`SHMEM_C_EXSCAN',
`SHMEM_FUNCTION_ATTRIBUTES int SHPRE()shmemx_$1_$4_exscan(shmem_team_t team, $2 *dest, const $2 *source, size_t nelems);')dnl

SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_C_EXSCAN', `sum')

/* Prototype template for the typed inclusive scan routines.  Each expansion
 * declares an int-returning shmemx_<name>_<op>_inscan(team, dest, source,
 * nelems), built from the first (type name), second (element type) and
 * fourth (operation name) template arguments.  The binder below is invoked
 * with the operation sum.
 * NOTE(review): the binder is defined elsewhere; presumably it expands the
 * template once per supported type -- confirm. */
define(`SHMEM_C_INSCAN',
`SHMEM_FUNCTION_ATTRIBUTES int SHPRE()shmemx_$1_$4_inscan(shmem_team_t team, $2 *dest, const $2 *source, size_t nelems);')dnl

SHMEM_BIND_C_COLL_SUM_PROD(`SHMEM_C_INSCAN', `sum')

/* Performance Counter Query Routines */
SHMEM_FUNCTION_ATTRIBUTES void SHPRE()shmemx_pcntr_get_issued_write(shmem_ctx_t ctx, uint64_t *cntr_value);
SHMEM_FUNCTION_ATTRIBUTES void SHPRE()shmemx_pcntr_get_issued_read(shmem_ctx_t ctx, uint64_t *cntr_value);
Expand Down
240 changes: 238 additions & 2 deletions src/collectives.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
coll_type_t shmem_internal_barrier_type = AUTO;
coll_type_t shmem_internal_bcast_type = AUTO;
coll_type_t shmem_internal_reduce_type = AUTO;
coll_type_t shmem_internal_scan_type = AUTO;
coll_type_t shmem_internal_collect_type = AUTO;
coll_type_t shmem_internal_fcollect_type = AUTO;
long *shmem_internal_barrier_all_psync;
Expand Down Expand Up @@ -206,6 +207,18 @@ shmem_internal_collectives_init(void)
} else {
RAISE_WARN_MSG("Ignoring bad reduction algorithm '%s'\n", type);
}
}
if (shmem_internal_params.SCAN_ALGORITHM_provided) {
type = shmem_internal_params.SCAN_ALGORITHM;
if (0 == strcmp(type, "auto")) {
shmem_internal_scan_type = AUTO;
} else if (0 == strcmp(type, "linear")) {
shmem_internal_scan_type = LINEAR;
} else if (0 == strcmp(type, "ring")) {
shmem_internal_scan_type = RING;
} else {
RAISE_WARN_MSG("Ignoring bad scan algorithm '%s'\n", type);
}
}
if (shmem_internal_params.COLLECT_ALGORITHM_provided) {
type = shmem_internal_params.COLLECT_ALGORITHM;
Expand Down Expand Up @@ -613,7 +626,7 @@ shmem_internal_op_to_all_linear(void *target, const void *source, size_t count,
SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0);

/* send data, ack, and wait for completion */
shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count * type_size,
shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count, type_size,
PE_start, op, datatype, &completion);
shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion);
shmem_internal_fence(SHMEM_CTX_DEFAULT);
Expand Down Expand Up @@ -819,7 +832,7 @@ shmem_internal_op_to_all_tree(void *target, const void *source, size_t count, si
/* send data, ack, and wait for completion */
shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target,
(num_children == 0) ? source : target,
count * type_size, parent,
count, type_size, parent,
op, datatype, &completion);
shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion);
shmem_internal_fence(SHMEM_CTX_DEFAULT);
Expand Down Expand Up @@ -971,6 +984,229 @@ shmem_internal_op_to_all_recdbl_sw(void *target, const void *source, size_t coun
}


/*****************************************
*
* SCAN
*
*****************************************/
/*
 * Linear prefix scan over the active set {PE_start, PE_start + PE_stride, ...}
 * containing PE_size members.
 *
 *   target    buffer that accumulates the scan result (written remotely)
 *   source    local contribution; may alias target (in-place scan)
 *   count     number of elements; a zero count returns immediately
 *   type_size size in bytes of one element
 *   pWrk      not used by this algorithm
 *   pSync     pSync[0] carries the handshake below; pSync + 2 is handed to
 *             shmem_internal_sync() for the in-place case
 *   op        reduction operator applied by the remote atomic updates
 *   datatype  element datatype passed to the remote atomic updates
 *   scantype  0 for an inclusive scan, 1 for an exclusive scan
 *
 * Protocol: the first PE of the set seeds the target buffers with its own
 * contribution and signals every other member.  Each other member then
 * combines its contribution into the target of every member at or after its
 * own position (strictly after, for an exclusive scan), acknowledges to the
 * first PE, and waits for the final release signal.
 */
void
shmem_internal_scan_linear(void *target, const void *source, size_t count, size_t type_size,
                           int PE_start, int PE_stride, int PE_size, void *pWrk, long *pSync,
                           shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype)
{
    long zero = 0, one = 1;
    long completion = 0;
    int free_source = 0;
    int pe, i;

    if (count == 0) return;

    /* In-place scan: copy source data to a temporary buffer so we can use
     * the symmetric buffer to accumulate scan data. */
    if (target == source) {
        void *tmp = malloc(count * type_size);

        if (NULL == tmp)
            RAISE_ERROR_MSG("Unable to allocate %zub temporary buffer\n", count*type_size);

        shmem_internal_copy_self(tmp, target, count * type_size);
        free_source = 1;
        source = tmp;

        shmem_internal_sync(PE_start, PE_stride, PE_size, pSync + 2);
    }

    if (PE_start == shmem_internal_my_pe) {

        /* Initialize target buffer.  The put will flush any atomic cache
         * value that may currently exist. */
        if (scantype) {
            /* Exclusive scan: the first PE's own result excludes its own
             * value, so overwrite its target with count * type_size zero
             * bytes. */
            uint8_t *zeroes = calloc(count, type_size);

            if (NULL == zeroes)
                RAISE_ERROR_MSG("Unable to allocate %zub temporary buffer\n", count*type_size);

            shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size,
                                  shmem_internal_my_pe, &completion);
            shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion);
            shmem_internal_quiet(SHMEM_CTX_DEFAULT);
            free(zeroes);
        }

        /* Send contribution to all (skipping ourselves for an exclusive scan) */
        for (pe = PE_start + PE_stride*scantype, i = scantype ;
             i < PE_size ;
             i++, pe += PE_stride) {

            shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size,
                                  pe, &completion);
            shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion);
            shmem_internal_fence(SHMEM_CTX_DEFAULT);
        }

        /* Tell every other member that its target has been initialized */
        for (pe = PE_start + PE_stride, i = 1 ;
             i < PE_size ;
             i++, pe += PE_stride) {
            shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe);
        }

        /* Wait for every other member to acknowledge that it has finished
         * sending its contributions */
        SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, PE_size - 1);

        /* reset pSync */
        shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe);
        SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0);

        /* Release everyone: all contributions have been applied */
        for (pe = PE_start + PE_stride, i = 1 ;
             i < PE_size ;
             i++, pe += PE_stride) {
            shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), pe);
        }
    } else {
        /* Position of this PE inside the active set.  The send loop below
         * counts set positions, not global PE numbers; the two only coincide
         * when PE_start is 0 and PE_stride is 1.  PE_stride is non-zero here
         * because this PE is a member other than PE_start. */
        const int my_idx = (shmem_internal_my_pe - PE_start) / PE_stride;

        /* wait for clear to initialization */
        SHMEM_WAIT(pSync, 0);

        /* reset pSync */
        shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe);
        SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0);

        /* Send contribution to all members at or after our own position
         * (strictly after, for an exclusive scan) */
        for (pe = shmem_internal_my_pe + PE_stride*scantype, i = my_idx + scantype ;
             i < PE_size ;
             i++, pe += PE_stride) {

            shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count, type_size,
                                   pe, op, datatype, &completion);
            shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion);
            shmem_internal_fence(SHMEM_CTX_DEFAULT);
        }

        /* Acknowledge to the first PE that our contributions are out */
        shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one),
                              PE_start, SHM_INTERNAL_SUM, SHM_INTERNAL_LONG);

        /* Wait for the final release from the first PE */
        SHMEM_WAIT(pSync, 0);

        /* reset pSync */
        shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe);
        SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0);
    }

    if (free_source)
        free((void *)source);
}


/*
 * Ring prefix scan over the active set {PE_start, PE_start + PE_stride, ...}
 * containing PE_size members.
 *
 *   target    buffer that accumulates the scan result (written remotely)
 *   source    local contribution; may alias target (in-place scan)
 *   count     number of elements; a zero count returns immediately
 *   type_size size in bytes of one element
 *   pWrk      not used by this algorithm
 *   pSync     pSync[0] carries the handshake below; pSync + 2 is handed to
 *             shmem_internal_sync() for the in-place case
 *   op        reduction operator applied by the remote atomic updates
 *   datatype  element datatype passed to the remote atomic updates
 *   scantype  0 for an inclusive scan, 1 for an exclusive scan
 *
 * Protocol: the first PE of the set seeds the target buffers with its own
 * contribution and signals its successor.  Each member, once signalled,
 * combines its contribution into the target of every member at or after its
 * own position (strictly after, for an exclusive scan), signals its own
 * successor, and acknowledges to the first PE, which waits for PE_size - 1
 * acknowledgements.
 */
void
shmem_internal_scan_ring(void *target, const void *source, size_t count, size_t type_size,
                         int PE_start, int PE_stride, int PE_size, void *pWrk, long *pSync,
                         shm_internal_op_t op, shm_internal_datatype_t datatype, int scantype)
{
    long zero = 0, one = 1;
    long completion = 0;
    int free_source = 0;
    int pe, i;

    /* Nothing to do; checked before the in-place copy so that no temporary
     * buffer is allocated (and then leaked) for an empty scan. */
    if (count == 0) return;

    /* In-place scan: copy source data to a temporary buffer so we can use
     * the symmetric buffer to accumulate scan data. */
    if (target == source) {
        void *tmp = malloc(count * type_size);

        if (NULL == tmp)
            RAISE_ERROR_MSG("Unable to allocate %zub temporary buffer\n", count*type_size);

        shmem_internal_copy_self(tmp, target, count * type_size);
        free_source = 1;
        source = tmp;

        shmem_internal_sync(PE_start, PE_stride, PE_size, pSync + 2);
    }

    if (PE_start == shmem_internal_my_pe) {

        /* Initialize target buffer.  The put will flush any atomic cache
         * value that may currently exist. */
        if (scantype) {
            /* Exclusive scan: the first PE's own result excludes its own
             * value, so overwrite its target with count * type_size zero
             * bytes. */
            uint8_t *zeroes = calloc(count, type_size);

            if (NULL == zeroes)
                RAISE_ERROR_MSG("Unable to allocate %zub temporary buffer\n", count*type_size);

            shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, zeroes, count * type_size,
                                  shmem_internal_my_pe, &completion);
            shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion);
            shmem_internal_quiet(SHMEM_CTX_DEFAULT);
            free(zeroes);
        }

        /* Send contribution to all (skipping ourselves for an exclusive scan) */
        for (pe = PE_start + PE_stride*scantype, i = scantype ;
             i < PE_size ;
             i++, pe += PE_stride) {

            shmem_internal_put_nb(SHMEM_CTX_DEFAULT, target, source, count * type_size,
                                  pe, &completion);
            shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion);
            shmem_internal_fence(SHMEM_CTX_DEFAULT);
        }

        /* Let next pe know that it's safe to send to us.  We sit at set
         * position 0, so a successor exists whenever the set has more than
         * one member. */
        if (PE_size > 1)
            shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride);

        /* Wait for others to acknowledge sending data */
        SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, PE_size - 1);

        /* reset pSync */
        shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe);
        SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0);

    } else {
        /* Position of this PE inside the active set.  The loop bound and
         * the successor test below count set positions, not global PE
         * numbers; the two only coincide when PE_start is 0 and PE_stride
         * is 1.  PE_stride is non-zero here because this PE is a member
         * other than PE_start. */
        const int my_idx = (shmem_internal_my_pe - PE_start) / PE_stride;

        /* wait for clear to send */
        SHMEM_WAIT(pSync, 0);

        /* reset pSync */
        shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &zero, sizeof(zero), shmem_internal_my_pe);
        SHMEM_WAIT_UNTIL(pSync, SHMEM_CMP_EQ, 0);

        /* Send contribution to all members at or after our own position
         * (strictly after, for an exclusive scan) */
        for (pe = shmem_internal_my_pe + PE_stride*scantype, i = my_idx + scantype ;
             i < PE_size ;
             i++, pe += PE_stride) {

            shmem_internal_atomicv(SHMEM_CTX_DEFAULT, target, source, count, type_size,
                                   pe, op, datatype, &completion);
            shmem_internal_put_wait(SHMEM_CTX_DEFAULT, &completion);
            shmem_internal_fence(SHMEM_CTX_DEFAULT);
        }

        /* Let next pe know that it's safe to send to us */
        if (my_idx + 1 < PE_size)
            shmem_internal_put_scalar(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one), shmem_internal_my_pe + PE_stride);

        /* Acknowledge to the first PE that our contributions are out */
        shmem_internal_atomic(SHMEM_CTX_DEFAULT, pSync, &one, sizeof(one),
                              PE_start, SHM_INTERNAL_SUM, SHM_INTERNAL_LONG);
    }

    if (free_source)
        free((void *)source);
}
/*****************************************
*
* COLLECT (variable size)
Expand Down
Loading
Loading