Skip to content

Commit

Permalink
Fixing a couple security concerns in raft-dask nccl unique id gener…
Browse files Browse the repository at this point in the history
…ation (#1785)

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Divye Gala (https://github.com/divyegala)

URL: #1785
  • Loading branch information
cjnolet authored Aug 30, 2023
1 parent 8bffb76 commit a4c9613
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions cpp/include/raft/comms/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,17 @@ void build_comms_nccl_ucx(
* @}
*/

inline void nccl_unique_id_from_char(ncclUniqueId* id, char* uniqueId, int size)
inline void nccl_unique_id_from_char(ncclUniqueId* id, char* uniqueId)
{
memcpy(id->internal, uniqueId, size);
memcpy(id->internal, uniqueId, NCCL_UNIQUE_ID_BYTES);
}

inline void get_unique_id(char* uid, int size)
inline void get_nccl_unique_id(char* uid)
{
ncclUniqueId id;
ncclGetUniqueId(&id);

memcpy(uid, id.internal, size);
memcpy(uid, id.internal, NCCL_UNIQUE_ID_BYTES);
}
}; // namespace comms
}; // end namespace raft
12 changes: 6 additions & 6 deletions python/raft-dask/raft_dask/common/nccl.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -26,10 +26,9 @@ from libcpp cimport bool


cdef extern from "raft/comms/std_comms.hpp" namespace "raft::comms":
void get_unique_id(char *uid, int size) except +
void get_nccl_unique_id(char *uid) except +
void nccl_unique_id_from_char(ncclUniqueId *id,
char *uniqueId,
int size) except +
char *uniqueId) except +

cdef extern from "nccl.h":

Expand Down Expand Up @@ -80,8 +79,9 @@ def unique_id():
128-byte unique id : str
"""
cdef char *uid = <char *> malloc(NCCL_UNIQUE_ID_BYTES * sizeof(char))
get_unique_id(uid, NCCL_UNIQUE_ID_BYTES)
get_nccl_unique_id(uid)
c_str = uid[:NCCL_UNIQUE_ID_BYTES-1]
c_str
free(uid)
return c_str

Expand Down Expand Up @@ -132,7 +132,7 @@ cdef class nccl:
self.rank = rank

cdef ncclUniqueId *ident = <ncclUniqueId*>malloc(sizeof(ncclUniqueId))
nccl_unique_id_from_char(ident, commId, NCCL_UNIQUE_ID_BYTES)
nccl_unique_id_from_char(ident, commId)

comm_ = <ncclComm_t*>self.comm

Expand Down

0 comments on commit a4c9613

Please sign in to comment.