Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

shuffle client/server: enable reply flag in connect + cb for ep disco… #25

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,7 @@ case class ExternalUcxClientWorker(val worker: UcpWorker,
@`inline`
def requestAddress(localServer: InetSocketAddress): Unit = {
executor.post(() => shuffleServers.computeIfAbsent("0.0.0.0", _ => {
doConnect(localServer, ExternalAmId.ADDRESS,
UcpConstants.UCP_AM_SEND_FLAG_REPLY)._1
doConnect(localServer, ExternalAmId.ADDRESS)._1
}))
}

Expand All @@ -331,8 +330,7 @@ case class ExternalUcxClientWorker(val worker: UcpWorker,
}

private def doConnect(shuffleServer: InetSocketAddress,
amId: Int = ExternalAmId.CONNECT,
flag: Long = 0): (UcpEndpoint, UcpRequest) = {
amId: Int): (UcpEndpoint, UcpRequest) = {
val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode()
.setSocketAddress(shuffleServer).sendClientId()
.setErrorHandler(new UcpEndpointErrorHandler() {
Expand All @@ -353,7 +351,8 @@ case class ExternalUcxClientWorker(val worker: UcpWorker,
val req = ep.sendAmNonBlocking(
amId, UcxUtils.getAddress(header), header.remaining(),
UcxUtils.getAddress(workerAddress), workerAddress.remaining(),
UcpConstants.UCP_AM_SEND_FLAG_EAGER | flag, new UcxCallback() {
UcpConstants.UCP_AM_SEND_FLAG_EAGER | UcpConstants.UCP_AM_SEND_FLAG_REPLY,
new UcxCallback() {
override def onSuccess(request: UcpRequest): Unit = {
connectNext()
}
Expand All @@ -368,7 +367,8 @@ case class ExternalUcxClientWorker(val worker: UcpWorker,
}

private def startConnection(shuffleServer: InetSocketAddress): (UcpEndpoint, UcpRequest) = {
connectingServers.computeIfAbsent(shuffleServer, _ => doConnect(shuffleServer))
connectingServers.computeIfAbsent(shuffleServer, _ =>
doConnect(shuffleServer, ExternalAmId.CONNECT))
}

private def getConnection(host: String): UcpEndpoint = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class ExternalUcxServerTransport(
}

def handleConnect(handler: ExternalUcxServerWorker,
clientWorker: UcxWorkerId, address: ByteBuffer): Unit = {
clientWorker: UcxWorkerId): Unit = {
workerMap.getOrElseUpdate(clientWorker.appId, {
new TrieMap[UcxWorkerId, Unit]
}).getOrElseUpdate(clientWorker, Unit)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ case class ExternalUcxServerWorker(val worker: UcpWorker,
private[ucx] val executor = new UcxWorkerThread(
worker, transport.ucxShuffleConf.useWakeup)

private val endpoints = mutable.Set.empty[UcpEndpoint]
private val emptyCallback = () => {}
private val endpoints = mutable.HashMap.empty[UcpEndpoint, () => Unit]
private val listener = worker.newListener(
new UcpListenerParams().setSockAddr(new InetSocketAddress("0.0.0.0", port))
.setConnectionHandler((ucpConnectionRequest: UcpConnectionRequest) => {
Expand All @@ -43,9 +44,9 @@ case class ExternalUcxServerWorker(val worker: UcpWorker,
new UcpEndpointParams().setConnectionRequest(ucpConnectionRequest)
.setPeerErrorHandlingMode().setErrorHandler(errorHandler)
.setName(s"Endpoint to $clientAddress"))
endpoints.add(ep)
endpoints.getOrElseUpdate(ep, emptyCallback)
} catch {
case e: UcxException => logWarning(s"Accept $clientAddress fail: $e")
case e: UcxException => logError(s"Accept $clientAddress fail: $e")
}
}))

Expand All @@ -57,7 +58,7 @@ case class ExternalUcxServerWorker(val worker: UcpWorker,
} else {
logWarning(s"Ep $ucpEndpoint got an error: $errorString")
}
endpoints.remove(ucpEndpoint)
endpoints.remove(ucpEndpoint).foreach(_())
ucpEndpoint.close()
}
}
Expand Down Expand Up @@ -94,15 +95,15 @@ case class ExternalUcxServerWorker(val worker: UcpWorker,

// AM to get worker address for client worker and connect server workers to it
worker.setAmRecvHandler(ExternalAmId.CONNECT,
(headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => {
(headerAddress: Long, headerSize: Long, amData: UcpAmData, ep: UcpEndpoint) => {
val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt)
val shuffleClient = UcxWorkerId.deserialize(header)
val workerAddress = UnsafeUtils.getByteBufferView(amData.getDataAddress,
amData.getLength.toInt)
val copiedAddress = ByteBuffer.allocateDirect(workerAddress.remaining)
copiedAddress.put(workerAddress)
connected(shuffleClient, copiedAddress)
transport.handleConnect(this, shuffleClient, copiedAddress)
endpoints.put(ep, () => doDisconnect(shuffleClient))
UcsConstants.STATUS.UCS_OK
}, UcpConstants.UCP_AM_FLAG_WHOLE_MSG )
// Main RPC thread. Reply with ucpAddress.
Expand Down Expand Up @@ -131,7 +132,7 @@ case class ExternalUcxServerWorker(val worker: UcpWorker,
}
if (!endpoints.isEmpty) {
logInfo(s"$workerId closing ${endpoints.size} eps")
endpoints.map(
endpoints.keys.map(
_.closeNonBlockingForce()).foreach(req =>
while (!req.isCompleted){
worker.progress()
Expand All @@ -158,20 +159,21 @@ case class ExternalUcxServerWorker(val worker: UcpWorker,
}

@`inline`
private def doDisconnect(workerId: UcxWorkerId): Unit = {
private def doDisconnect(shuffleClient: UcxWorkerId): Unit = {
try {
workerReqs.remove(workerId).foreach(reqs => {
workerReqs.remove(shuffleClient).foreach(reqs => {
val inCompletes = reqs.filterNot(_.isCompleted)
if (inCompletes.nonEmpty) {
inCompletes.foreach(worker.cancelRequest(_))
logInfo(s"$workerId canceled ${inCompletes.size} requests")
logInfo(s"Canceled ${inCompletes.size} requests to $shuffleClient")
}
})
Option(shuffleClients.remove(workerId)).foreach(ep => {
Option(shuffleClients.remove(shuffleClient)).foreach(ep => {
ep.closeNonBlockingFlush()
logInfo(s"Disconnect $shuffleClient")
})
} catch {
case e: Throwable => logWarning(s"$workerId disconnect $e")
case e: Throwable => logWarning(s"Disconnect $shuffleClient: $e")
}
}

Expand All @@ -188,6 +190,7 @@ case class ExternalUcxServerWorker(val worker: UcpWorker,
worker.newEndpoint(new UcpEndpointParams()
.setErrorHandler(new UcpEndpointErrorHandler {
override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
logInfo(s"Connection to $shuffleClient closed: $errorString")
shuffleClients.remove(shuffleClient)
workerReqs.remove(shuffleClient)
ucpEndpoint.close()
Expand All @@ -197,7 +200,7 @@ case class ExternalUcxServerWorker(val worker: UcpWorker,
.setUcpAddress(workerAddress))
})
} catch {
case e: UcxException => logWarning(s"Connect back to $shuffleClient fail: $e")
case e: UcxException => logWarning(s"Connection to $shuffleClient failed: $e")
}
}

Expand All @@ -220,7 +223,7 @@ case class ExternalUcxServerWorker(val worker: UcpWorker,
new mutable.ListBuffer[UcpRequest]()
})

while (reqs.nonEmpty && reqs(0).isCompleted) {
while (reqs.nonEmpty && reqs.head.isCompleted) {
reqs.remove(0)
}
reqs.append(req)
Expand Down