From b36eba3b663a28fedaa7e33daf2c01386d885975 Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Wed, 25 Oct 2023 16:37:40 +0800 Subject: [PATCH 01/18] fix bug of log Signed-off-by: forgive_dengkai --- .../webank/eggroll/core/env/SysInfoLinux.java | 17 +++++++++++++++-- .../core/resourcemanager/ClusterManager.scala | 19 ++++++++++++------- .../ClusterManagerBootstrap.scala | 4 ++++ .../ClusterResourceManager.scala | 5 +---- 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java b/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java index aa87c074e..e92fae9b9 100644 --- a/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java +++ b/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java @@ -166,8 +166,21 @@ public int getGpuNumber() throws IOException { shellExecutorClk.execute(); String cmdReturnString = shellExecutorClk.getOutput(); if (StringUtils.isNotEmpty(cmdReturnString)) - result = cmdReturnString.split("\n").length-1; - }catch(Exception ignore){} + { + String[] elems = cmdReturnString.split("\n"); + for(String elem:elems){ + if(elem.contains("NVIDIA")){ + result=result+1; + } + } + } + if(result==0){ + System.err.println("nvidia-smi cmd return "+cmdReturnString); + } + + }catch(Exception ignore){ + } + System.err.println("nvidia-smi gpu return "+result); return result; } diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala index e5c8101cf..a14f5ba65 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala @@ -104,16 +104,17 @@ object ClusterManagerService extends Logging { residualHeartbeatMap.remove(k) } catch { case e: Throwable => - e.printStackTrace() logError(s"kill residual processor error: ${e.getMessage}") } }) } catch { case e: Throwable => - e.printStackTrace() + logError("") + }finally { + Thread.sleep(CONFKEY_NODE_MANAGER_HEARTBEAT_INTERVAL.get().toInt) } - Thread.sleep(CONFKEY_NODE_MANAGER_HEARTBEAT_INTERVAL.get().toInt) + } },"REDIDUAL_PROCESS_CHECK_THREAD" ) @@ -130,8 +131,10 @@ object ClusterManagerService extends Logging { case e: Throwable => e.printStackTrace() + }finally { + Thread.sleep(CONFKEY_NODE_MANAGER_HEARTBEAT_INTERVAL.get().toInt) } - Thread.sleep(CONFKEY_NODE_MANAGER_HEARTBEAT_INTERVAL.get().toInt) + } } ,"NODE_PROCESS_CHECK_THREAD") @@ -241,14 +244,16 @@ object ClusterManagerService extends Logging { catch { case e: Throwable=> logError(s"session watcher handle session ${session.id} error ${e.getMessage}") - e.printStackTrace() +// e.printStackTrace() } } - Thread.sleep(EGGROLL_SESSION_STATUS_CHECK_INTERVAL_MS.get().toLong) + }catch { case e: Throwable=> logError(s"session watcher handle error ${e.getMessage}") - e.printStackTrace() +// e.printStackTrace() + }finally { + Thread.sleep(EGGROLL_SESSION_STATUS_CHECK_INTERVAL_MS.get().toLong) } } } diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManagerBootstrap.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManagerBootstrap.scala index be15cecb5..72b38e627 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManagerBootstrap.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManagerBootstrap.scala @@ -19,6 +19,7 @@ class ClusterManagerBootstrap extends BootstrapBase with Logging { private var port = 0 private var standaloneTag = "0" + //private var sessionId = "er_session_null" override def init(args: Array[String]): Unit = { @@ -198,6 +199,8 @@ class ClusterManagerBootstrap extends BootstrapBase with Logging { } override def start(): Unit = { + + RdbConnectionPool.dataSource.start() // TODO:0: use user's config val server = GrpcServerUtils.createServer(port = this.port, grpcServices = List(new CommandService,new ClusterManagerExtendTransferService)) server.start() @@ -206,6 +209,7 @@ class ClusterManagerBootstrap extends BootstrapBase with Logging { StaticErConf.setPort(port) logInfo(s"$standaloneTag server started at port $port") println(s"$standaloneTag server started at port $port") + ClusterResourceManager.start() ClusterManagerService.start() } diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterResourceManager.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterResourceManager.scala index b81a82437..8ef3dbab4 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterResourceManager.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterResourceManager.scala @@ -152,7 +152,6 @@ object ClusterResourceManager extends Logging{ }catch { case e:Throwable => { logError("dispatch resource error: "+e.getMessage) - e.printStackTrace() } } @@ -163,9 +162,9 @@ object ClusterResourceManager extends Logging{ var maxResourceCountThread = new Thread(()=> { + logInfo("SYSTEM-RESOURCE-COUNT-THREAD start") while (true) { try{ - var serverNodes = getServerNodeWithResource(); countMaxResource(serverNodes:Array[ErServerNode]) serverNodes.map( node=>{ @@ -174,7 +173,6 @@ object ClusterResourceManager extends Logging{ }catch { case e:Throwable => { logError("count resource error: "+e.getMessage) - e.printStackTrace() } } Thread.sleep(EGGROLL_RESOURCE_COUNT_INTERVAL.get().toInt) @@ -206,7 +204,6 @@ object ClusterResourceManager extends Logging{ }catch { case e:Throwable => logError(s"lock clean error ${e.getMessage}") - // e.printStackTrace() } }) Thread.sleep(EGGROLL_RESOURCE_LOCK_EXPIRE_INTERVAL.get().toInt) From 511d818f7eac24daaa89efd5197087b313777200 Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Wed, 25 Oct 2023 16:51:22 +0800 Subject: [PATCH 02/18] fix status Signed-off-by: forgive_dengkai --- .../webank/eggroll/core/deepspeed/job/JobServiceHandler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm/core/main/scala/com/webank/eggroll/core/deepspeed/job/JobServiceHandler.scala b/jvm/core/main/scala/com/webank/eggroll/core/deepspeed/job/JobServiceHandler.scala index 98da0e740..59a6ce2cc 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/deepspeed/job/JobServiceHandler.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/deepspeed/job/JobServiceHandler.scala @@ -404,7 +404,7 @@ object JobServiceHandler extends Logging { } logInfo(s"killing job send to node over $sessionId") var now = System.currentTimeMillis() - smDao.updateSessionMain(sessionMeta.copy(status = SessionStatus.ERROR), afterCall = defaultSessionCallback) + smDao.updateSessionMain(sessionMeta.copy(status = SessionStatus.KILLED), afterCall = defaultSessionCallback) var cost = System.currentTimeMillis()-now logInfo(s"killing job update session main over $sessionId ,cost $cost") }finally { From 89b7dfb70c5410b3d34571af37fea4b8a7c1be11 Mon Sep 17 00:00:00 2001 From: chengtcc <864261919@qq.com> Date: Wed, 25 Oct 2023 21:39:41 +0800 Subject: [PATCH 03/18] add kill,stop job function Signed-off-by: chengtcc <864261919@qq.com> --- python/client/cli/commands/task.py | 30 +++++++++++++++++---- python/client/sdk/api/task.py | 40 ++++++++++++++++++++-------- python/client/sdk/submit/commands.py | 1 + 3 files changed, 55 insertions(+), 16 deletions(-) diff --git a/python/client/cli/commands/task.py b/python/client/cli/commands/task.py index 09991c048..56a00aa2d 100644 --- a/python/client/cli/commands/task.py +++ b/python/client/cli/commands/task.py @@ -61,13 +61,13 @@ def submit(ctx, **kwargs): client: EggrollClient = ctx.obj["client"] session_id = f"deepspeed_session_{datetime.datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" print(f'session_id:{session_id}') - response = client.task.submit(world_size=world_size, files=files, resource_options=resource_options, + client.task.submit(world_size=world_size, files=files, resource_options=resource_options, options=options, command_arguments=command_arguments, session_id=session_id) while True: response = client.task.query_status(session_id=session_id) - print(f'task session_id:{session_id} status:{response.status}') - if response.status != "NEW": + print(f'task session_id:{session_id} status:{response["status"]}') + if response["status"] != "NEW": break log_type = kwargs.get("log_type") if not kwargs.get("log_type") else "stdout" response = client.task.get_log(sessionId=session_id, logType=log_type) @@ -80,8 +80,25 @@ def submit(ctx, **kwargs): def query(ctx, **kwargs): client: EggrollClient = ctx.obj["client"] response = client.task.query_status(session_id=kwargs.get("session_id")) - response_json = {"status": response.status} - prettify(response_json) + prettify(response) + + +@task.command("kill", short_help="Kill job") +@click.option("--session-id", type=click.STRING, required=True, help="session id") +@click.pass_context +def kill(ctx, **kwargs): + client: EggrollClient = ctx.obj["client"] + response = client.task.kill_job(session_id=kwargs.get("session_id")) + prettify(response) + + +@task.command("stop", short_help="Stop job") +@click.option("--session-id", type=click.STRING, required=True, help="session id") +@click.pass_context +def stop(ctx, **kwargs): + client: EggrollClient = ctx.obj["client"] + response = client.task.stop_job(session_id=kwargs.get("session_id")) + prettify(response) @task.command("download", short_help="Download task output") @@ -93,6 +110,9 @@ def query(ctx, **kwargs): def download(ctx, **kwargs): client: EggrollClient = ctx.obj["client"] download_dir = kwargs.get("download_dir") + status = client.task.query_status(session_id=kwargs.get("session_id")) + if status["message"]: + return prettify(status) os.makedirs(download_dir, exist_ok=True) with tempfile.TemporaryDirectory() as temp_dir: diff --git a/python/client/sdk/api/task.py b/python/client/sdk/api/task.py index 64933a342..8a4c572ad 100644 --- a/python/client/sdk/api/task.py +++ b/python/client/sdk/api/task.py @@ -96,35 +96,53 @@ def submit( def query_status(self, session_id): query_job_status_request = deepspeed_pb2.QueryJobStatusRequest(session_id=session_id) - return self._get_client().do_sync_request( + response = self._get_client().do_sync_request( query_job_status_request, output_type=deepspeed_pb2.QueryJobStatusResponse, command_uri=JobCommands.QUERY_JOB_STATUS, ) + if not response.status: + return {"code": 0, "message": f"session_id:{session_id} is not exists"} + return {"status": response.status} - def query_session(self, session_id): + def query_job(self, session_id): query_job_request = deepspeed_pb2.QueryJobRequest(session_id=session_id) query_response = self._get_client().do_sync_request( query_job_request, output_type=deepspeed_pb2.QueryJobResponse, command_uri=JobCommands.QUERY_JOB ) return query_response - def kill(self, session_id): + def kill_job(self, session_id): + status = self.query_status(session_id=session_id) + if not status.get("status"): + return status kill_job_request = deepspeed_pb2.KillJobRequest(session_id=session_id) - kill_response = self._get_client().do_sync_request( + response = self._get_client().do_sync_request( kill_job_request, output_type=deepspeed_pb2.KillJobResponse, command_uri=JobCommands.KILL_JOB ) - return kill_response + response = {"session_id":response.session_id} + return response + + def stop_job(self, session_id): + status = self.query_status(session_id=session_id) + if not status.get("status"): + return status + stop_job_request = deepspeed_pb2.StopJobRequest(session_id=session_id) + response = self._get_client().do_sync_request( + stop_job_request, output_type=deepspeed_pb2.StopJobResponse, command_uri=JobCommands.STOP_JOB + ) + response = {"session_id": response.session_id} + return response def await_finished(self, session_id, timeout: int = 0, poll_interval: int = 1): deadline = time.time() + timeout query_response = self.query_status(session_id) while timeout <= 0 or time.time() < deadline: - if query_response.status not in {SessionStatus.NEW, SessionStatus.ACTIVE}: + if query_response.get("status", "") not in {SessionStatus.NEW, SessionStatus.ACTIVE}: break query_response = self.query_status(session_id) time.sleep(poll_interval) - return query_response.status + return query_response["status"] def download_job( self, @@ -235,7 +253,7 @@ def download_job_to( compress_level: int = 1, ): query_status = self.query_status(session_id) - if not query_status.status: + if not query_status.get("status"): raise ValueError(f'not found session_id:{session_id}') download_job_response = self.download_job_v2(session_id, ranks, content_type, compress_method, compress_level) for key, value in download_job_response.items(): @@ -243,8 +261,8 @@ def download_job_to( with open(path, "wb") as f: f.write(value) - @staticmethod - def writer(stream, session_id, result_queue): + # @staticmethod + def writer(self, stream, session_id, result_queue): try: for res in stream: if str(res.code) == "0": @@ -257,7 +275,7 @@ def writer(stream, session_id, result_queue): ret = {"code": res.code, "message": f"info error"} result_queue.put(ret) except Exception as e: - ret = {"code": "112", "message": f" grpc off "} + ret = {"message": ret["status"]} result_queue.put(ret) def cancel_stream(self, session_id, stream, flag): diff --git a/python/client/sdk/submit/commands.py b/python/client/sdk/submit/commands.py index 82d853854..4ef5bb976 100644 --- a/python/client/sdk/submit/commands.py +++ b/python/client/sdk/submit/commands.py @@ -31,6 +31,7 @@ class JobCommands: QUERY_JOB_STATUS = _create_command_uri(prefix, "queryJobStatus") QUERY_JOB = _create_command_uri(prefix, "queryJob") KILL_JOB = _create_command_uri(prefix, "killJob") + STOP_JOB = _create_command_uri(prefix, "stopJob") DOWNLOAD_JOB = _create_command_uri(prefix, "downloadJob") PREPARE_DOWNLOAD_JOB = _create_command_uri(prefix, "prepareJobDownload") From d44b863d05bff2ec96ea181ac4c6e62eded24448 Mon Sep 17 00:00:00 2001 From: chengtcc <864261919@qq.com> Date: Wed, 25 Oct 2023 21:49:26 +0800 Subject: [PATCH 04/18] add status Signed-off-by: chengtcc <864261919@qq.com> --- python/client/sdk/api/task.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/client/sdk/api/task.py b/python/client/sdk/api/task.py index 8a4c572ad..aab675e35 100644 --- a/python/client/sdk/api/task.py +++ b/python/client/sdk/api/task.py @@ -275,7 +275,8 @@ def writer(self, stream, session_id, result_queue): ret = {"code": res.code, "message": f"info error"} result_queue.put(ret) except Exception as e: - ret = {"message": ret["status"]} + ret = self.query_status(session_id) + ret = {"status": ret["status"]} result_queue.put(ret) def cancel_stream(self, session_id, stream, flag): From 61444d8cb785fea14234299406adab8d99c75f73 Mon Sep 17 00:00:00 2001 From: chengtcc <864261919@qq.com> Date: Thu, 26 Oct 2023 10:18:28 +0800 Subject: [PATCH 05/18] update setup Signed-off-by: chengtcc <864261919@qq.com> --- python/client/setup.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/client/setup.py b/python/client/setup.py index 7d44ced71..d90caec44 100644 --- a/python/client/setup.py +++ b/python/client/setup.py @@ -14,10 +14,12 @@ # limitations under the License. # -*- coding: utf-8 -*- +import os from setuptools import find_packages, setup -packages = find_packages(".") +packages = find_packages('..') +filtered_packages = [pkg for pkg in packages if pkg.startswith("client")] package_data = {"": ["*"]} install_requires = [ "click", @@ -42,7 +44,7 @@ "maintainer": None, "maintainer_email": None, "url": "https://fate.fedai.org/", - "packages": packages, + "packages": filtered_packages, "include_package_data": True, "package_data": package_data, "install_requires": install_requires, @@ -50,5 +52,5 @@ "python_requires": ">=3.8", } - +os.chdir('..') setup(**setup_kwargs) From 5e62fdeb570fa69c5069285431bcf41a6257ce01 Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Thu, 26 Oct 2023 14:58:12 +0800 Subject: [PATCH 06/18] fix bug of log Signed-off-by: forgive_dengkai --- .../eggroll/core/util/ProcessUtils.java | 13 ++++++++++ .../core/resourcemanager/NodeManager.scala | 25 +++++++++++++------ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/jvm/core/main/java/com/webank/eggroll/core/util/ProcessUtils.java b/jvm/core/main/java/com/webank/eggroll/core/util/ProcessUtils.java index fbe2c939e..9d4600f99 100644 --- a/jvm/core/main/java/com/webank/eggroll/core/util/ProcessUtils.java +++ b/jvm/core/main/java/com/webank/eggroll/core/util/ProcessUtils.java @@ -3,10 +3,23 @@ import com.sun.jna.Platform; import java.io.*; +import java.lang.management.ManagementFactory; +import java.lang.management.RuntimeMXBean; public class ProcessUtils { + public static int getCurrentPid(){ + RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean(); + String name = runtime.getName(); // format: "pid@hostname" + try { + return Integer.parseInt(name.substring(0, name.indexOf('@'))); + } catch (Exception e) { + return -1; + } + } + + public static Process createProcess(String command) throws IOException { String[] cmd = new String[] { "/bin/sh", "-c", command }; return Runtime.getRuntime().exec(cmd); diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala index 6ee6ed044..77e45a100 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala @@ -25,11 +25,11 @@ object NodeManagerMeta { var status=INIT var serverNodeId = -1:Long; var clusterId = -1:Long; - def refreshServerNodeMetaIntoFile(): Unit = { - + var ip:String =StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_HOST, NetUtils.getLocalHost ) ; + var port:Integer = StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_PORT).toInt + def refreshServerNodeMetaIntoFile(): Unit = { var filePath = CoreConfKeys.EGGROLL_DATA_DIR.get()+ StringConstants.SLASH+"NodeManagerMeta"; var gson= new Gson() - FileSystemUtils.fileWriter(filePath, gson.toJson(NodeManagerMeta)) } def loadNodeManagerMetaFromFile():Unit = { @@ -38,14 +38,22 @@ object NodeManagerMeta { var gson = new Gson() var content = FileSystemUtils.fileReader(filePath) var contentMap = gson.fromJson(content,classOf[NodeManagerMeta]); - NodeManagerMeta.serverNodeId = contentMap.serverNodeId - NodeManagerMeta.clusterId = contentMap.clusterId + + if(NodeManagerMeta.ip.equals(contentMap.ip)&&NodeManagerMeta.port==contentMap.port){ + NodeManagerMeta.serverNodeId = contentMap.serverNodeId + NodeManagerMeta.clusterId = contentMap.clusterId + }else{ + System.err.println("load meta file , found invalid content : "+content) + } } } } case class NodeManagerMeta(status :String, - serverNodeId : Long, - clusterId : Long) + serverNodeId : Long, + clusterId : Long, + ip: String, + port:Integer + ) trait NodeManager { def startContainers(sessionMeta: ErSessionMeta): ErSessionMeta @@ -319,8 +327,9 @@ object NodeResourceManager extends Logging { NodeManagerMeta.serverNodeId = nodeHeartBeat.node.id NodeManagerMeta.clusterId = nodeHeartBeat.node.clusterId logInfo(s"get node id ${NodeManagerMeta.serverNodeId} from cluster-manager ") - NodeManagerMeta.refreshServerNodeMetaIntoFile() NodeManagerMeta.status = HEALTHY + NodeManagerMeta.refreshServerNodeMetaIntoFile() + } logInfo(s"get node id ${NodeManagerMeta.serverNodeId} from cluster-manager ") } From 6f7c590ad14586be10b36878d7c85b4f0327213c Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Sun, 29 Oct 2023 18:22:58 +0800 Subject: [PATCH 07/18] add extend env conf Signed-off-by: forgive_dengkai --- BUILD_INFO | 2 +- conf/node-extend-env.properties | 0 .../webank/eggroll/core/env/SysInfoLinux.java | 29 ++-- .../webank/eggroll/core/util/NetUtils.java | 151 ++++++++++-------- .../eggroll/core/constant/ConfKeys.scala | 3 +- .../containers/ContainersServiceHandler.scala | 10 +- .../core/resourcemanager/NodeManager.scala | 6 +- .../NodeManagerBootstrap.scala | 12 +- .../webank/eggroll/core/session/ErConf.scala | 7 + jvm/eggroll-shard/pom.xml | 2 +- jvm/pom.xml | 2 +- .../com/webank/eggroll/rollsite/Util.scala | 1 + python/eggroll/__init__.py | 2 +- 13 files changed, 135 insertions(+), 92 deletions(-) create mode 100644 conf/node-extend-env.properties diff --git a/BUILD_INFO b/BUILD_INFO index e190ea29c..750fef4d2 100644 --- a/BUILD_INFO +++ b/BUILD_INFO @@ -1 +1 @@ -eggroll.version=2.5.2 +eggroll.version=2.5.3 diff --git a/conf/node-extend-env.properties b/conf/node-extend-env.properties new file mode 100644 index 000000000..e69de29bb diff --git a/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java b/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java index e92fae9b9..602287c5f 100644 --- a/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java +++ b/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java @@ -19,6 +19,8 @@ import com.google.common.annotations.VisibleForTesting; +import com.webank.eggroll.core.constant.ErConfKey; +import com.webank.eggroll.core.constant.NodeManagerConfKeys; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -156,7 +158,14 @@ public int getGpuNumber() throws IOException { int result = 0; try{ - String[] cmd = new String[] { "/bin/sh", "-c", "nvidia-smi --query-gpu=name --format=csv, noheader" }; + ErConfKey shellConfig = NodeManagerConfKeys.CONFKEY_NODE_MANAGER_GPU_NUM_SHELL(); + String shell = shellConfig.get(); + if(StringUtils.isEmpty(shell)){ + return result; + } + String eggrollHome = System.getenv("EGGROLL_HOME"); + String path = eggrollHome+"/bin/gpu/"+shell; + String[] cmd = new String[] { "/bin/sh", "-c", path }; ShellCommandExecutor shellExecutorClk = new ShellCommandExecutor(cmd); // name // NVIDIA Tesla V100-SXM2-32GB @@ -165,22 +174,16 @@ public int getGpuNumber() throws IOException { // NVIDIA Tesla V100-SXM2-32GB shellExecutorClk.execute(); String cmdReturnString = shellExecutorClk.getOutput(); - if (StringUtils.isNotEmpty(cmdReturnString)) - { - String[] elems = cmdReturnString.split("\n"); - for(String elem:elems){ - if(elem.contains("NVIDIA")){ - result=result+1; - } - } + try { + result = Integer.getInteger(cmdReturnString); + }catch (Throwable e){ + } if(result==0){ - System.err.println("nvidia-smi cmd return "+cmdReturnString); + System.err.println("get gpu num exec "+path +" return "+cmdReturnString); } - }catch(Exception ignore){ - } - System.err.println("nvidia-smi gpu return "+result); + }catch(Exception ignore){} return result; } diff --git a/jvm/core/main/java/com/webank/eggroll/core/util/NetUtils.java b/jvm/core/main/java/com/webank/eggroll/core/util/NetUtils.java index 133e229c4..b7b1dfc48 100644 --- a/jvm/core/main/java/com/webank/eggroll/core/util/NetUtils.java +++ b/jvm/core/main/java/com/webank/eggroll/core/util/NetUtils.java @@ -16,14 +16,15 @@ package com.webank.eggroll.core.util; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.net.*; -import java.util.Enumeration; -import java.util.Optional; +import java.util.*; import java.util.concurrent.ThreadLocalRandom; import java.util.regex.Pattern; @@ -52,12 +53,7 @@ public class NetUtils { private static volatile InetAddress LOCAL_ADDRESS = null; public static void main(String[] args) { -// System.out.println(NetUtils.getLocalHost()); -// System.out.println(NetUtils.getAvailablePort()); -// System.out.println(NetUtils.getLocalAddress()); -// System.out.println(NetUtils.getLocalIp()); -// System.out.println(NetUtils.getIpByHost("127.0.0.1")); -// System.out.println(NetUtils.getLocalAddress0("")); + System.out.println(NetUtils.getLocalHost("")); } public static int getRandomPort() { @@ -180,17 +176,19 @@ static InetAddress normalizeV6Address(Inet6Address address) { return address; } - public static String getLocalHost() { - InetAddress address = getLocalAddress(); - return address == null ? LOCALHOST_VALUE : address.getHostAddress(); + public static String getLocalHost(String deviceName) { + String result = ""; + InetAddress address = getLocalAddress(deviceName); + result= address == null ? LOCALHOST_VALUE : address.getHostAddress(); + return result; } - public static InetAddress getLocalAddress() { + public static InetAddress getLocalAddress(String deviceName) { if (LOCAL_ADDRESS != null) { return LOCAL_ADDRESS; } - InetAddress localAddress = getLocalAddress0(""); + InetAddress localAddress = getLocalAddress0(deviceName); LOCAL_ADDRESS = localAddress; return localAddress; } @@ -209,26 +207,26 @@ private static Optional toValidAddress(InetAddress address) { } - public static String getLocalIp() { - - try { - InetAddress inetAddress = getLocalAddress0("eth0"); - if (inetAddress != null) { - return inetAddress.getHostAddress(); - } else { - inetAddress = getLocalAddress0(""); - } - if (inetAddress != null) { - return inetAddress.getHostAddress(); - } else { - throw new RuntimeException("can not get local ip"); - } - - } catch (Throwable e) { - logger.error(e.getMessage(), e); - } - return ""; - } +// public static String getLocalIp() { +// +// try { +// InetAddress inetAddress = getLocalAddress0("eth0"); +// if (inetAddress != null) { +// return inetAddress.getHostAddress(); +// } else { +// inetAddress = getLocalAddress0(""); +// } +// if (inetAddress != null) { +// return inetAddress.getHostAddress(); +// } else { +// throw new RuntimeException("can not get local ip"); +// } +// +// } catch (Throwable e) { +// logger.error(e.getMessage(), e); +// } +// return ""; +// } private static String getIpByEthNum(String ethNum) { try { @@ -258,10 +256,31 @@ public static String getOsName() { String osName = System.getProperty("os.name"); return osName; } + private static InetAddress chooseAddressFromInterface(NetworkInterface network){ + Enumeration addresses = network.getInetAddresses(); + while (addresses.hasMoreElements()) { + try { + Optional addressOp = toValidAddress(addresses.nextElement()); + if (addressOp.isPresent()) { + try { + if (addressOp.get().isReachable(10000)) { + return addressOp.get(); + } + } catch (IOException e) { + // ignore + } + } + } catch (Throwable e) { + logger.warn(e.getMessage()); + } + } + return null; + } private static InetAddress getLocalAddress0(String name) { InetAddress localAddress = null; + InetAddress other = null; try { localAddress = InetAddress.getLocalHost(); Optional addressOp = toValidAddress(localAddress); @@ -271,49 +290,46 @@ private static InetAddress getLocalAddress0(String name) { localAddress = null; } } catch (Throwable e) { - logger.warn(e.getMessage()); + e.printStackTrace(); } +// if(StringUtils.isNotEmpty(name)||localAddress.getHostAddress().equals(LOCALHOST_VALUE)) { - try { - Enumeration interfaces = NetworkInterface.getNetworkInterfaces(); - if (null == interfaces) { - return localAddress; - } - while (interfaces.hasMoreElements()) { - try { + try { + Enumeration interfaces = NetworkInterface.getNetworkInterfaces(); + if (null == interfaces) { + return localAddress; + } + Map networkIMap= Maps.newLinkedHashMap(); + while (interfaces.hasMoreElements()){ NetworkInterface network = interfaces.nextElement(); - if (network.isLoopback() || network.isVirtual() || !network.isUp()) { - continue; - } - if (StringUtils.isNotEmpty(name)) { - if (!network.getName().equals(name)) { + networkIMap.put(network.getName(),network); + } + if (StringUtils.isNotEmpty(name)&&networkIMap.get(name)!=null) { + return chooseAddressFromInterface(networkIMap.get(name)); + } + + List names = Lists.newArrayList(networkIMap.keySet()); + for(Map.Entry entry:networkIMap.entrySet()){ + try { + if (entry.getValue().isLoopback() || entry.getValue().isVirtual() || !entry.getValue().isUp()) { continue; } - } - Enumeration addresses = network.getInetAddresses(); - while (addresses.hasMoreElements()) { - try { - Optional addressOp = toValidAddress(addresses.nextElement()); - if (addressOp.isPresent()) { - try { - if (addressOp.get().isReachable(10000)) { - return addressOp.get(); - } - } catch (IOException e) { - // ignore - } - } - } catch (Throwable e) { - logger.warn(e.getMessage()); + other= chooseAddressFromInterface(entry.getValue()); + if(other!=null){ + break; } + }catch (Throwable e){ + } - } catch (Throwable e) { - logger.warn(e.getMessage()); } + + } catch (Throwable e) { + logger.warn(e.getMessage()); } - } catch (Throwable e) { - logger.warn(e.getMessage()); - } + if(localAddress.getHostAddress().equals(LOCALHOST_VALUE)&&other!=null){ + localAddress=other; + } + return localAddress; } @@ -531,4 +547,5 @@ private static Integer getNumOfIpSegment(String ipSegment, boolean isIpv4) { return Integer.parseInt(ipSegment, 16); } + } diff --git a/jvm/core/main/scala/com/webank/eggroll/core/constant/ConfKeys.scala b/jvm/core/main/scala/com/webank/eggroll/core/constant/ConfKeys.scala index d26225292..a6748bd35 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/constant/ConfKeys.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/constant/ConfKeys.scala @@ -141,10 +141,11 @@ object NodeManagerConfKeys { val CONFKEY_NODE_MANAGER_HOST = "eggroll.resourcemanager.nodemanager.host" val CONFKEY_NODE_MANAGER_PORT = "eggroll.resourcemanager.nodemanager.port" var CONFKEY_NODE_MANAGER_ID = "eggroll.resourcemanager.nodemanager.id" + var CONFKEY_NODE_MANAGER_GPU_NUM_SHELL= ErConfKey("eggroll.resourcemanager.nodemanager.gpu.num.shell","nvidia.sh") var CONFKEY_NODE_MANAGER_HEARTBEAT_INTERVAL = ErConfKey("eggroll.resourcemanager.nodemanager.heartbeat.interval",10000) val CONFKEY_NODE_MANAGER_CPU_VCORES = ErConfKey( "eggroll.resourcemanager.nodemanager.cpu.vcores",16) val CONFKEY_NODE_MANAGER_GPU_VCORES = ErConfKey( "eggroll.resourcemanager.nodemanager.gpu.vcores",16) - + var CONFKEY_NODE_MANAGER_NET_DEVICE = "eggroll.resourcemanager.nodemanager.net.device" val CONFKEY_NODE_MANAGER_CONTAINERS_DATA_DIR = "eggroll.resourcemanager.nodemanager.containers.data.dir" diff --git a/jvm/core/main/scala/com/webank/eggroll/core/containers/ContainersServiceHandler.scala b/jvm/core/main/scala/com/webank/eggroll/core/containers/ContainersServiceHandler.scala index 6a57287d9..59b5549f6 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/containers/ContainersServiceHandler.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/containers/ContainersServiceHandler.scala @@ -8,7 +8,8 @@ import com.webank.eggroll.core.containers.meta._ import com.webank.eggroll.core.error.PathNotExistException import com.webank.eggroll.core.meta.ErProcessor import com.webank.eggroll.core.resourcemanager.NodeManagerMeta -import com.webank.eggroll.core.session.StaticErConf +import com.webank.eggroll.core.session +import com.webank.eggroll.core.session.{ExtendEnvConf, StaticErConf} import com.webank.eggroll.core.transfer.Extend import com.webank.eggroll.core.util.{Logging, ProcessUtils} import io.grpc.Status @@ -88,14 +89,18 @@ class ContainersServiceHandler(implicit ec: ExecutionContext, private def startDeepspeedContainers(startDeepspeedContainerRequest: StartDeepspeedContainerRequest): StartContainersResponse = { val sessionId = startDeepspeedContainerRequest.sessionId logInfo(s"(sessionId=$sessionId) starting deepspeed containers") + + startDeepspeedContainerRequest.deepspeedConfigs.par.foreach { case (containerId, deepspeedConfig) => + var envMap :Map[String, String] =startDeepspeedContainerRequest.environmentVariables.++(ExtendEnvConf.getAll) + logInfo("containerId "+containerId+"env map : "+envMap) val container = new DeepSpeedContainer( sessionId = sessionId, processorId = containerId, deepspeedContainerConfig = new WarpedDeepspeedContainerConfig(deepspeedConfig), containerWorkspace = getContainerWorkspace(sessionId, deepspeedConfig.rank), commandArguments = startDeepspeedContainerRequest.commandArguments, - environmentVariables = startDeepspeedContainerRequest.environmentVariables, + environmentVariables =envMap , files = startDeepspeedContainerRequest.files, zippedFiles = startDeepspeedContainerRequest.zippedFiles, options = startDeepspeedContainerRequest.options @@ -319,4 +324,5 @@ object ContainersServiceHandler extends Logging { logInfo(s"zipped path: $path") byteStream.toByteArray } + } \ No newline at end of file diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala index 77e45a100..05b7026c3 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala @@ -5,6 +5,7 @@ import com.webank.eggroll.core.constant.{ClusterManagerConfKeys, CoreConfKeys, N import com.webank.eggroll.core.meta.{ErEndpoint, ErNodeHeartbeat, ErProcessor, ErResource, ErResourceAllocation, ErServerNode, ErSessionMeta} import com.webank.eggroll.core.session.RuntimeErConf import com.webank.eggroll.core.client.ClusterManagerClient +import com.webank.eggroll.core.constant.NodeManagerConfKeys.CONFKEY_NODE_MANAGER_NET_DEVICE import com.webank.eggroll.core.constant.ServerNodeStatus.{HEALTHY, INIT} import com.webank.eggroll.core.session.{RuntimeErConf, StaticErConf} import com.webank.eggroll.core.env.{Shell, SysInfoLinux} @@ -25,7 +26,8 @@ object NodeManagerMeta { var status=INIT var serverNodeId = -1:Long; var clusterId = -1:Long; - var ip:String =StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_HOST, NetUtils.getLocalHost ) ; + var ip:String =StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_HOST, + NetUtils.getLocalHost( StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_NET_DEVICE, "") )) ; var port:Integer = StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_PORT).toInt def refreshServerNodeMetaIntoFile(): Unit = { var filePath = CoreConfKeys.EGGROLL_DATA_DIR.get()+ StringConstants.SLASH+"NodeManagerMeta"; @@ -315,7 +317,7 @@ object NodeResourceManager extends Logging { seq +=1 var nodeHeartBeat = client.nodeHeartbeat(ErNodeHeartbeat (id= seq ,node =queryNodeResource(ErServerNode(id = NodeManagerMeta.serverNodeId, nodeType = ServerNodeTypes.NODE_MANAGER, - endpoint = ErEndpoint(host = StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_HOST, NetUtils.getLocalHost), + endpoint = ErEndpoint(host = StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_HOST, NetUtils.getLocalHost(StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_NET_DEVICE, ""))), port = StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_PORT).toInt), status = NodeManagerMeta.status) )) diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManagerBootstrap.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManagerBootstrap.scala index 3111550cc..9d8861e74 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManagerBootstrap.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManagerBootstrap.scala @@ -4,9 +4,9 @@ import com.webank.eggroll.core.BootstrapBase import com.webank.eggroll.core.command.{CommandRouter, CommandService} import com.webank.eggroll.core.constant._ import com.webank.eggroll.core.containers.ContainersServiceHandler -import com.webank.eggroll.core.ex.grpc.{NodeManagerExtendTransferService} +import com.webank.eggroll.core.ex.grpc.NodeManagerExtendTransferService import com.webank.eggroll.core.meta.{ErProcessor, ErResourceAllocation, ErServerNode, ErSessionMeta} -import com.webank.eggroll.core.session.StaticErConf +import com.webank.eggroll.core.session.{ExtendEnvConf, StaticErConf} import com.webank.eggroll.core.transfer.GrpcServerUtils import com.webank.eggroll.core.util.{CommandArgsUtils, Logging} import io.grpc.Server @@ -28,8 +28,14 @@ class NodeManagerBootstrap extends BootstrapBase with Logging { this.confPath = cmd.getOptionValue('c', "./conf/eggroll.properties") // val sessionId = cmd.getOptionValue('s') - StaticErConf.addProperties(confPath) + StaticErConf.addProperties(confPath) + var extendConfPath:String = this.confPath.replace("eggroll.properties","node-extend-env.properties") + var extendEnvConfFile = new File(extendConfPath) + if(extendEnvConfFile.exists()) { + log.info("load extend env config file : "+extendConfPath) + ExtendEnvConf.addProperties(extendConfPath) + } // register services // To support parameters to NodeManagerService, // we instantiate a NodeManagerService instance here diff --git a/jvm/core/main/scala/com/webank/eggroll/core/session/ErConf.scala b/jvm/core/main/scala/com/webank/eggroll/core/session/ErConf.scala index 70fb038d6..94d6e2018 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/session/ErConf.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/session/ErConf.scala @@ -164,6 +164,13 @@ case class RuntimeErConf(prop: Properties = new Properties()) extends ErConf { override protected def getConf(): Properties = conf } +object ExtendEnvConf extends ErConf{ + override def getPort(): Int = ??? + + override def getModuleName(): String = ??? +} + + object StaticErConf extends ErConf { var port: Int = -1 var moduleName: String = _ diff --git a/jvm/eggroll-shard/pom.xml b/jvm/eggroll-shard/pom.xml index b77730e10..67d71640b 100644 --- a/jvm/eggroll-shard/pom.xml +++ b/jvm/eggroll-shard/pom.xml @@ -6,7 +6,7 @@ com.webank.eggroll eggroll-all - 2.5.2 + 2.5.3 eggroll-shard diff --git a/jvm/pom.xml b/jvm/pom.xml index 190a6edb6..322c53e18 100644 --- a/jvm/pom.xml +++ b/jvm/pom.xml @@ -201,7 +201,7 @@ 4.0.0 - 2.5.2 + 2.5.3 512m diff --git a/jvm/roll_site/main/scala/com/webank/eggroll/rollsite/Util.scala b/jvm/roll_site/main/scala/com/webank/eggroll/rollsite/Util.scala index 13eadcfa6..50a8c3de8 100644 --- a/jvm/roll_site/main/scala/com/webank/eggroll/rollsite/Util.scala +++ b/jvm/roll_site/main/scala/com/webank/eggroll/rollsite/Util.scala @@ -22,6 +22,7 @@ object Util { def main(args: Array[String]) { println(hashMD5("abcdefg")) println(hashMD5("abcde")) + } } \ No newline at end of file diff --git a/python/eggroll/__init__.py b/python/eggroll/__init__.py index 50508c349..2272da851 100644 --- a/python/eggroll/__init__.py +++ b/python/eggroll/__init__.py @@ -14,4 +14,4 @@ # limitations under the License. # -__version__ = "2.5.2" +__version__ = "2.5.3" From 40c9bc2ddbd18788a8acefa034a1adde7f33964d Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Tue, 31 Oct 2023 08:24:33 +0800 Subject: [PATCH 08/18] add extend env conf Signed-off-by: forgive_dengkai --- bin/gpu/nvidia.sh | 3 +++ 1 file changed, 3 insertions(+) create mode 100755 bin/gpu/nvidia.sh diff --git a/bin/gpu/nvidia.sh b/bin/gpu/nvidia.sh new file mode 100755 index 000000000..78432fbd4 --- /dev/null +++ b/bin/gpu/nvidia.sh @@ -0,0 +1,3 @@ +result=$(nvidia-smi --query-gpu=name --format=csv, noheader|grep 'NVIDIA'|wc -l) +echo $result + From 92030a71242fc427951fe771c01072555349468c Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Tue, 31 Oct 2023 16:31:58 +0800 Subject: [PATCH 09/18] fix bug Signed-off-by: forgive_dengkai --- jvm/core/main/java/com/webank/eggroll/core/util/NetUtils.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jvm/core/main/java/com/webank/eggroll/core/util/NetUtils.java b/jvm/core/main/java/com/webank/eggroll/core/util/NetUtils.java index b7b1dfc48..9dab47e5d 100644 --- a/jvm/core/main/java/com/webank/eggroll/core/util/NetUtils.java +++ b/jvm/core/main/java/com/webank/eggroll/core/util/NetUtils.java @@ -285,7 +285,7 @@ private static InetAddress getLocalAddress0(String name) { localAddress = InetAddress.getLocalHost(); Optional addressOp = toValidAddress(localAddress); if (addressOp.isPresent()) { - return addressOp.get(); +// return addressOp.get(); } else { localAddress = null; } @@ -326,7 +326,7 @@ private static InetAddress getLocalAddress0(String name) { } catch (Throwable e) { logger.warn(e.getMessage()); } - if(localAddress.getHostAddress().equals(LOCALHOST_VALUE)&&other!=null){ + if(localAddress==null||localAddress.getHostAddress().equals(LOCALHOST_VALUE)&&other!=null){ localAddress=other; } From 4c00a65de6e5b2864de3ee9b35497ae81844bd65 Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Tue, 31 Oct 2023 21:09:55 +0800 Subject: [PATCH 10/18] add extend env conf Signed-off-by: forgive_dengkai --- .../webank/eggroll/core/env/SysInfoLinux.java | 80 +++++++------------ .../core/resourcemanager/NodeManager.scala | 10 +-- 2 files changed, 33 insertions(+), 57 deletions(-) diff --git a/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java b/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java index 602287c5f..896adc241 100644 --- a/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java +++ b/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java @@ -154,40 +154,47 @@ private static long getConf(String attr) { public int getGpuNumber() throws IOException { + + String gpus = null; int result = 0; try{ ErConfKey shellConfig = NodeManagerConfKeys.CONFKEY_NODE_MANAGER_GPU_NUM_SHELL(); + String shell = shellConfig.get(); - if(StringUtils.isEmpty(shell)){ - return result; - } + String eggrollHome = System.getenv("EGGROLL_HOME"); + String path = eggrollHome+"/bin/gpu/"+shell; - String[] cmd = new String[] { "/bin/sh", "-c", path }; - ShellCommandExecutor shellExecutorClk = new ShellCommandExecutor(cmd); -// name -// NVIDIA Tesla V100-SXM2-32GB -// NVIDIA Tesla V100-SXM2-32GB -// NVIDIA Tesla V100-SXM2-32GB -// NVIDIA Tesla V100-SXM2-32GB - shellExecutorClk.execute(); - String cmdReturnString = shellExecutorClk.getOutput(); - try { - result = Integer.getInteger(cmdReturnString); - }catch (Throwable e){ - } - if(result==0){ - System.err.println("get gpu num exec "+path +" return "+cmdReturnString); - } + if(StringUtils.isNotEmpty(path)) { + String[] cmd = new String[]{"/bin/sh", "-c", path}; + ShellCommandExecutor shellExecutorClk = new ShellCommandExecutor(cmd); + shellExecutorClk.execute(); + String cmdReturnString = shellExecutorClk.getOutput(); + try { + cmdReturnString=cmdReturnString.replace("\n",""); + cmdReturnString=cmdReturnString.replace("\r",""); + result = new Integer(cmdReturnString); + } catch (Throwable e) { + e.printStackTrace(); + } + - }catch(Exception ignore){} + + System.err.println("get gpu num exec "+path +" return "+cmdReturnString +" result :"+result) ; + }else{ + System.err.println("get gpu shell is not set"); + } + }catch(Exception ignore){ + ignore.printStackTrace(); + } return result; } + public int getProcess(int pid) { @@ -745,38 +752,7 @@ public long getStorageBytesWritten() { * * @param args - arguments to this calculator test */ - public static void main(String[] args) { - SysInfoLinux plugin = new SysInfoLinux(); -// System.out.println("Physical memory Size (bytes) : " -// + plugin.getPhysicalMemorySize()); -// System.out.println("Total Virtual memory Size (bytes) : " -// + plugin.getVirtualMemorySize()); -// System.out.println("Available Physical memory Size (bytes) : " -// + plugin.getAvailablePhysicalMemorySize()); -// System.out.println("Total Available Virtual memory Size (bytes) : " -// + plugin.getAvailableVirtualMemorySize()); -// System.out.println("Number of Processors : " + plugin.getNumProcessors()); -// System.out.println("CPU frequency (kHz) : " + plugin.getCpuFrequency()); -// System.out.println("Cumulative CPU time (ms) : " + -// plugin.getCumulativeCpuTime()); -// System.out.println("Total network read (bytes) : " -// + plugin.getNetworkBytesRead()); -// System.out.println("Total network written (bytes) : " -// + plugin.getNetworkBytesWritten()); -// System.out.println("Total storage read (bytes) : " -// + plugin.getStorageBytesRead()); -// System.out.println("Total storage written (bytes) : " -// + plugin.getStorageBytesWritten()); -// try { -// // Sleep so we can compute the CPU usage -// Thread.sleep(500L); -// } catch (InterruptedException e) { -// // do nothing -// } -// System.out.println("CPU usage % : " + plugin.getCpuUsagePercentage()); - - plugin.getProcess(1000); - } + @VisibleForTesting void setReadCpuInfoFile(boolean readCpuInfoFileValue) { diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala index 05b7026c3..52652fb52 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala @@ -173,7 +173,7 @@ class NodeManagerService extends NodeManager with Logging { object NodeResourceManager extends Logging { - private var sysInfo= if(Shell.LINUX)new SysInfoLinux else null + private var sysInfo= new SysInfoLinux private var client = new ClusterManagerClient private var resourceEventQueue= new ArrayBlockingQueue[ResourceEvent](100); @@ -233,11 +233,11 @@ object NodeResourceManager extends Logging { } def getGpuSize():Long = { - if(Shell.LINUX){ +// if(Shell.LINUX){ sysInfo.getGpuNumber - }else{ - 0 - } +// }else{ +// 0 +// } } From 42fedf9c3008541e26a23b3a2575ade9425b18c6 Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Wed, 1 Nov 2023 08:55:45 +0800 Subject: [PATCH 11/18] add extend env conf Signed-off-by: forgive_dengkai --- .../webank/eggroll/core/env/SysInfoLinux.java | 47 ++++++++++--------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java b/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java index 896adc241..e01212379 100644 --- a/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java +++ b/jvm/core/main/java/com/webank/eggroll/core/env/SysInfoLinux.java @@ -2,11 +2,7 @@ package com.webank.eggroll.core.env; -import java.io.BufferedReader; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.InputStreamReader; -import java.io.IOException; +import java.io.*; import java.math.BigInteger; import java.nio.charset.Charset; import java.util.ArrayList; @@ -153,22 +149,36 @@ private static long getConf(String attr) { } - public int getGpuNumber() throws IOException { - + public int getGpuNumberV2() throws IOException { + String gpus = null; + int result = 0; + try { + String[] cmd = new String[]{"/bin/sh", "-c", "nvidia-smi --query-gpu=name --format=csv, noheader"}; + ShellCommandExecutor shellExecutorClk = new ShellCommandExecutor(cmd); + shellExecutorClk.execute(); + String cmdReturnString = shellExecutorClk.getOutput(); + if (StringUtils.isNotEmpty(cmdReturnString)) { + String[] elems = cmdReturnString.split("\n"); + for(String e:elems){ + if(e.contains("NVIDIA")) + result=result+1; + } + } + } catch (Exception ignore) { + } + return result; + } + public int getGpuNumber() throws IOException { String gpus = null; int result = 0; try{ - ErConfKey shellConfig = NodeManagerConfKeys.CONFKEY_NODE_MANAGER_GPU_NUM_SHELL(); - String shell = shellConfig.get(); - String eggrollHome = System.getenv("EGGROLL_HOME"); - String path = eggrollHome+"/bin/gpu/"+shell; - - if(StringUtils.isNotEmpty(path)) { + File file = new File(path); + if(StringUtils.isNotEmpty(path)&&file.exists()) { String[] cmd = new String[]{"/bin/sh", "-c", path}; ShellCommandExecutor shellExecutorClk = new ShellCommandExecutor(cmd); shellExecutorClk.execute(); @@ -180,9 +190,6 @@ public int getGpuNumber() throws IOException { } catch (Throwable e) { e.printStackTrace(); } - - - System.err.println("get gpu num exec "+path +" return "+cmdReturnString +" result :"+result) ; }else{ System.err.println("get gpu shell is not set"); @@ -190,6 +197,9 @@ public int getGpuNumber() throws IOException { }catch(Exception ignore){ ignore.printStackTrace(); } + if(result==0){ + result = getGpuNumberV2(); + } return result; } @@ -747,11 +757,6 @@ public long getStorageBytesWritten() { return numDisksBytesWritten; } - /** - * Test the {@link SysInfoLinux}. - * - * @param args - arguments to this calculator test - */ @VisibleForTesting From 40a33d9d1ac91ca39c4174c1bf36fa1c2ca6866f Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Thu, 2 Nov 2023 10:04:47 +0800 Subject: [PATCH 12/18] fix bug Signed-off-by: forgive_dengkai --- .../eggroll/core/constant/ConfKeys.scala | 2 +- .../core/resourcemanager/ClusterManager.scala | 51 ++++++++++--------- .../core/resourcemanager/NodeManager.scala | 26 +++++----- .../NodeManagerBootstrap.scala | 5 ++ 4 files changed, 48 insertions(+), 36 deletions(-) diff --git a/jvm/core/main/scala/com/webank/eggroll/core/constant/ConfKeys.scala b/jvm/core/main/scala/com/webank/eggroll/core/constant/ConfKeys.scala index a6748bd35..13fe49948 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/constant/ConfKeys.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/constant/ConfKeys.scala @@ -127,7 +127,7 @@ object ClusterManagerConfKeys { val CONFKEY_CLUSTER_MANAGER_DATASOURCE_DB_DEFAULT_AUTO_COMMIT = "eggroll.resourcemanager.clustermanager.datasource.db.default.auto.commit" val CONFKEY_CLUSTER_MANAGER_HOST = "eggroll.resourcemanager.clustermanager.host" val CONFKEY_CLUSTER_MANAGER_PORT = "eggroll.resourcemanager.clustermanager.port" - var CONFKEY_CLUSTER_MANAGER_NODE_HEARTBEAT_EXPIRED_COUNT = ErConfKey("eggroll.resourcemanager.clustermanager.node.heartbeat.expire.count",200) + var CONFKEY_CLUSTER_MANAGER_NODE_HEARTBEAT_EXPIRED_COUNT = ErConfKey("eggroll.resourcemanager.clustermanager.node.heartbeat.expire.count",30) val EGGROLL_RESOURCEMANAGER_CLUSTERMANAGER_JDBC_PASSWORD_DECRYPTOR = ErConfKey("eggroll.resourcemanager.clustermanager.jdbc.password.decryptor") val EGGROLL_RESOURCEMANAGER_CLUSTERMANAGER_JDBC_PASSWORD_DECRYPTOR_ARGS = ErConfKey("eggroll.resourcemanager.clustermanager.jdbc.password.decryptor.args") val EGGROLL_RESOURCEMANAGER_CLUSTERMANAGER_JDBC_PASSWORD_DECRYPTOR_ARGS_SPLITER = ErConfKey("eggroll.resourcemanager.clustermanager.jdbc.password.decryptor.args.spliter", ",") diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala index a14f5ba65..4f13cf869 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala @@ -359,35 +359,40 @@ class ClusterManagerService extends ClusterManager with Logging { override def nodeHeartbeat(nodeHeartbeat: ErNodeHeartbeat): ErNodeHeartbeat = synchronized { //logInfo(s" nodeHeartbeat ${nodeHeartbeat}") var serverNode = nodeHeartbeat.node - if (serverNode.id == -1) { - var existNode = queryNodeByEndPoint(serverNode) - if (existNode == null) { - logInfo(s"create new node ${serverNode}") - createNewNode(serverNode) - } else { - logInfo(s"node already exist ${existNode}") - serverNode = serverNode.copy(id = existNode.id) - updateNode(serverNode, true, true) - } - - } else { - if (nodeHeartbeatMap.contains(serverNode.id) && (nodeHeartbeatMap(serverNode.id).id < nodeHeartbeat.id)) { - //正常心跳 - updateNode(serverNode, false, true) - } else { - //nodemanger重启过 - var existNode = queryNodeById(serverNode) + if(!serverNode.status.equals(ServerNodeStatus.LOSS)) { + if (serverNode.id == -1) { + var existNode = queryNodeByEndPoint(serverNode) if (existNode == null) { - serverNode = createNewNode(serverNode) + logInfo(s"create new node ${serverNode}") + createNewNode(serverNode) } else { + logInfo(s"node already exist ${existNode}") + serverNode = serverNode.copy(id = existNode.id) updateNode(serverNode, true, true) } + } else { + if (nodeHeartbeatMap.contains(serverNode.id) && (nodeHeartbeatMap(serverNode.id).id < nodeHeartbeat.id)) { + //正常心跳 + updateNode(serverNode, false, true) + } else { + //nodemanger重启过 + var existNode = queryNodeById(serverNode) + if (existNode == null) { + serverNode = createNewNode(serverNode) + } else { + updateNode(serverNode, true, true) + } + } + } + nodeHeartbeatMap.put(serverNode.id, nodeHeartbeat); + nodeHeartbeat.copy(node = serverNode) + }else{ + logInfo(s"receive node ${serverNode.id} quit heart beat") + if (serverNode.id != -1) { + updateNode(serverNode, false, true) } + nodeHeartbeat } - nodeHeartbeatMap.put(serverNode.id, nodeHeartbeat); - nodeHeartbeat.copy(node = serverNode) } - - } diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala index 52652fb52..56a54bcd5 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManager.scala @@ -27,7 +27,7 @@ object NodeManagerMeta { var serverNodeId = -1:Long; var clusterId = -1:Long; var ip:String =StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_HOST, - NetUtils.getLocalHost( StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_NET_DEVICE, "") )) ; + NetUtils.getLocalHost( StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_NET_DEVICE, "eth0") )) ; var port:Integer = StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_PORT).toInt def refreshServerNodeMetaIntoFile(): Unit = { var filePath = CoreConfKeys.EGGROLL_DATA_DIR.get()+ StringConstants.SLASH+"NodeManagerMeta"; @@ -240,6 +240,16 @@ object NodeResourceManager extends Logging { // } } + def tryNodeHeartbeat():ErNodeHeartbeat ={ + client.nodeHeartbeat(ErNodeHeartbeat (id= seq ,node =queryNodeResource(ErServerNode(id = NodeManagerMeta.serverNodeId, + nodeType = ServerNodeTypes.NODE_MANAGER, + endpoint = ErEndpoint(host = NodeManagerMeta.ip, + port = NodeManagerMeta.port), + status = NodeManagerMeta.status)))) + + } + + def getAvailablePhysicalMemorySize():Long ={ if(Shell.LINUX){ @@ -307,22 +317,15 @@ object NodeResourceManager extends Logging { used=r.used.get(), allocated = r.allocated.get())})) } - + var seq :Long = 0 class HeartBeatThread extends Thread{ override def run(){ var notOver : Boolean = true - var seq :Long = 0 + while(notOver){ try { seq +=1 - var nodeHeartBeat = client.nodeHeartbeat(ErNodeHeartbeat (id= seq ,node =queryNodeResource(ErServerNode(id = NodeManagerMeta.serverNodeId, - nodeType = ServerNodeTypes.NODE_MANAGER, - endpoint = ErEndpoint(host = StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_HOST, NetUtils.getLocalHost(StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_NET_DEVICE, ""))), - port = StaticErConf.getString(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_PORT).toInt), - status = NodeManagerMeta.status) - )) - - ) + var nodeHeartBeat = tryNodeHeartbeat() if (nodeHeartBeat != null&&nodeHeartBeat.node!=null) { if (NodeManagerMeta.status.equals(INIT)) { if(nodeHeartBeat.node.id != -1) { @@ -338,7 +341,6 @@ object NodeResourceManager extends Logging { } }catch { case t: Throwable => -// t.printStackTrace() logError("node heart beat error "+t.getMessage) } Thread.sleep(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_HEARTBEAT_INTERVAL.get().toInt) diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManagerBootstrap.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManagerBootstrap.scala index 9d8861e74..b2b335c2d 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManagerBootstrap.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/NodeManagerBootstrap.scala @@ -6,6 +6,7 @@ import com.webank.eggroll.core.constant._ import com.webank.eggroll.core.containers.ContainersServiceHandler import com.webank.eggroll.core.ex.grpc.NodeManagerExtendTransferService import com.webank.eggroll.core.meta.{ErProcessor, ErResourceAllocation, ErServerNode, ErSessionMeta} +import com.webank.eggroll.core.resourcemanager.NodeResourceManager.tryNodeHeartbeat import com.webank.eggroll.core.session.{ExtendEnvConf, StaticErConf} import com.webank.eggroll.core.transfer.GrpcServerUtils import com.webank.eggroll.core.util.{CommandArgsUtils, Logging} @@ -144,6 +145,8 @@ class NodeManagerBootstrap extends BootstrapBase with Logging { override def shutdown(): Unit = { println("shutting down") + NodeManagerMeta.status=ServerNodeStatus.LOSS + tryNodeHeartbeat() // Gracefully shut down the ForkJoinPool if (forkJoinPool != null) { forkJoinPool.shutdown() @@ -161,6 +164,8 @@ class NodeManagerBootstrap extends BootstrapBase with Logging { if (server != null) { println("shutting down server") server.shutdown() + + println("server shutdown done") } println("shutting down done") From f9cf9a85aeb47557d141e3d0c153cfc6c123a6ce Mon Sep 17 00:00:00 2001 From: chengtcc <864261919@qq.com> Date: Thu, 2 Nov 2023 14:37:43 +0800 Subject: [PATCH 13/18] fix bug & update version Signed-off-by: chengtcc <864261919@qq.com> --- python/client/cli/commands/task.py | 2 ++ python/client/sdk/api/task.py | 6 ++++-- python/client/setup.py | 14 +++++++------- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/python/client/cli/commands/task.py b/python/client/cli/commands/task.py index 56a00aa2d..7a1ae1cfe 100644 --- a/python/client/cli/commands/task.py +++ b/python/client/cli/commands/task.py @@ -71,6 +71,8 @@ def submit(ctx, **kwargs): break log_type = kwargs.get("log_type") if not kwargs.get("log_type") else "stdout" response = client.task.get_log(sessionId=session_id, logType=log_type) + if response["status"]: + response = client.task.query_status(session_id=session_id) prettify(response) diff --git a/python/client/sdk/api/task.py b/python/client/sdk/api/task.py index aab675e35..767ce24ae 100644 --- a/python/client/sdk/api/task.py +++ b/python/client/sdk/api/task.py @@ -120,7 +120,8 @@ def kill_job(self, session_id): response = self._get_client().do_sync_request( kill_job_request, output_type=deepspeed_pb2.KillJobResponse, command_uri=JobCommands.KILL_JOB ) - response = {"session_id":response.session_id} + status = self.query_status(session_id=session_id) + response = {"session_id":response.session_id, "status": status.get("status")} return response def stop_job(self, session_id): @@ -131,7 +132,8 @@ def stop_job(self, session_id): response = self._get_client().do_sync_request( stop_job_request, output_type=deepspeed_pb2.StopJobResponse, command_uri=JobCommands.STOP_JOB ) - response = {"session_id": response.session_id} + status = self.query_status(session_id=session_id) + response = {"session_id": response.session_id, "status": status.get("status")} return response def await_finished(self, session_id, timeout: int = 0, poll_interval: int = 1): diff --git a/python/client/setup.py b/python/client/setup.py index d90caec44..08a1677de 100644 --- a/python/client/setup.py +++ b/python/client/setup.py @@ -23,19 +23,19 @@ package_data = {"": ["*"]} install_requires = [ "click", - "requests", - "grpcio", - "numba", - "numpy", - "protobuf", - "ruamel.yaml" + "requests<2.26.0", + "grpcio==1.46.3", + "numba==0.56.4", + "numpy==1.19.5", + "protobuf=3.19.6", + "ruamel.yaml==0.16" ] entry_points = {"console_scripts": ["eggroll = client.cli.eggroll:eggroll_cli"]} setup_kwargs = { "name": "eggroll-client", - "version": "2.5.1", + "version": "2.5.3", "description": "Clients for Eggroll", "long_description_content_type": "text/markdown", "long_description": "Clients for Eggroll", From a83abd190d62d7077d9a60740f7e4d721ee016dc Mon Sep 17 00:00:00 2001 From: forgive_dengkai Date: Thu, 2 Nov 2023 23:00:56 +0800 Subject: [PATCH 14/18] fix bug Signed-off-by: forgive_dengkai --- .../core/constant/TypesAndStatus.scala | 2 + .../deepspeed/job/JobServiceHandler.scala | 46 +++++++++++++------ .../core/resourcemanager/ClusterManager.scala | 4 +- .../ClusterResourceManager.scala | 32 +++++++------ 4 files changed, 54 insertions(+), 30 deletions(-) diff --git a/jvm/core/main/scala/com/webank/eggroll/core/constant/TypesAndStatus.scala b/jvm/core/main/scala/com/webank/eggroll/core/constant/TypesAndStatus.scala index 95d20fd94..4bd58e112 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/constant/TypesAndStatus.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/constant/TypesAndStatus.scala @@ -146,6 +146,8 @@ object ProcessorStatus { } object SessionStatus { + var WAITING_RESOURCE="WAITING_RESOURCE" + var ALLOCATE_RESOURCE_FAILED="ALLOCATE_RESOURCE_FAILED" val NEW = "NEW" var BEFORE_DESTORY = "BEFORE_DESTORY" val NEW_TIMEOUT = "NEW_TIMEOUT" diff --git a/jvm/core/main/scala/com/webank/eggroll/core/deepspeed/job/JobServiceHandler.scala b/jvm/core/main/scala/com/webank/eggroll/core/deepspeed/job/JobServiceHandler.scala index 59a6ce2cc..eb4356100 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/deepspeed/job/JobServiceHandler.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/deepspeed/job/JobServiceHandler.scala @@ -38,26 +38,40 @@ object JobServiceHandler extends Logging { def handleJobKill(killJobRequest: KillJobRequest): KillJobResponse = { val sessionId = killJobRequest.sessionId - killJob(sessionId, isTimeout = false) + killJob(sessionId, isTimeout = false,SessionStatus.KILLED) KillJobResponse(sessionId) } def handleJobStop(stopJobRequest: StopJobRequest): StopJobResponse = { val sessionId = stopJobRequest.sessionId - killJob(sessionId, isTimeout = false) + killJob(sessionId, isTimeout = false,SessionStatus.KILLED) StopJobResponse(sessionId) } def handleJobQuery(queryJobRequest: QueryJobRequest): QueryJobResponse = { val sessionId = queryJobRequest.sessionId - val status = smDao.getSessionMain(sessionId).status + var status = smDao.getSessionMain(sessionId).status + if(status.equals(SessionStatus.WAITING_RESOURCE)){ + status = SessionStatus.NEW + } + if(status.equals(SessionStatus.ALLOCATE_RESOURCE_FAILED)){ + status = SessionStatus.ERROR + } val processors = smDao.getSession(sessionId).processors QueryJobResponse(sessionId = sessionId, status = status, processors = processors) } def handleJobStatusQuery(queryJobStatusRequest: QueryJobStatusRequest): QueryJobStatusResponse = { val sessionId = queryJobStatusRequest.sessionId - val status = smDao.getSessionMain(sessionId).status + var status = smDao.getSessionMain(sessionId).status + if(status.equals(SessionStatus.WAITING_RESOURCE)){ + status = SessionStatus.NEW + } + if(status.equals(SessionStatus.ALLOCATE_RESOURCE_FAILED)){ + status = SessionStatus.ERROR + } + + QueryJobStatusResponse(sessionId = sessionId, status = status) } @@ -203,7 +217,7 @@ object JobServiceHandler extends Logging { val activeCount = session.activeProcCount if (activeCount < expectedWorldSize) { try { - killJob(sessionId, isTimeout = true) + killJob(sessionId, isTimeout = true,SessionStatus.ERROR) } catch { case e: Exception => logError(s"failed to kill job $sessionId", e) @@ -263,8 +277,9 @@ object JobServiceHandler extends Logging { try { //锁不能移到分配资源之前,会造成死锁 ClusterResourceManager.lockSession(sessionId) - if(!ClusterResourceManager.killJobMap.contains(sessionId)) { val registeredSessionMeta = smDao.getSession(submitJobRequest.sessionId) + if(!registeredSessionMeta.status.equals(SessionStatus.KILLED)) { + dispatchedProcessors = dispatchedProcessors.zip(registeredSessionMeta.processors).map { case ((processor, node), registeredProcessor) => (processor.copy(id = registeredProcessor.id), node) @@ -359,31 +374,36 @@ object JobServiceHandler extends Logging { required_old_status = Some(SessionStatus.NEW)) SubmitJobResponse(sessionId, activeProcessors) }else{ - logError(s"kill session ${sessionId} request was found") - throw new ErSessionException(s"kill session ${sessionId} request was found") + logError(s"killed session ${sessionId} request was found") + throw new ErSessionException(s"session ${sessionId} is killed") } } catch { case e: Exception => - killJob(sessionId, isTimeout = false) + killJob(sessionId, isTimeout = false,status = SessionStatus.ERROR) throw e }finally { ClusterResourceManager.unlockSession(sessionId) } } - def killJob(sessionId: String, isTimeout: Boolean): Unit = { + def killJob(sessionId: String, isTimeout: Boolean,status:String): Unit = { logInfo(s"try killing job $sessionId") try { ClusterResourceManager.lockSession(sessionId) - ClusterResourceManager.killJobMap.put(sessionId,System.currentTimeMillis()) if (!smDao.existSession(sessionId)) { return } val sessionMeta = smDao.getSession(sessionId) - if (StringUtils.equalsAny(sessionMeta.status, SessionStatus.FINISHED,SessionStatus.KILLED, SessionStatus.CLOSED, SessionStatus.ERROR)) { + if (StringUtils.equalsAny(sessionMeta.status, SessionStatus.FINISHED,SessionStatus.KILLED, SessionStatus.CLOSED, SessionStatus.ERROR,SessionStatus.ALLOCATE_RESOURCE_FAILED)) { return } + if (StringUtils.equalsAny(sessionMeta.status, SessionStatus.WAITING_RESOURCE)) { + logInfo(s"session id ${sessionId} status change from ${SessionStatus.WAITING_RESOURCE} to ${status}") + smDao.updateSessionStatus(sessionId=sessionId,status = status) + return + } + val serverNodeCrudOperator = new ServerNodeCrudOperator() var nodeAndProcessors = sessionMeta.processors .groupBy(p => p.serverNodeId) @@ -404,7 +424,7 @@ object JobServiceHandler extends Logging { } logInfo(s"killing job send to node over $sessionId") var now = System.currentTimeMillis() - smDao.updateSessionMain(sessionMeta.copy(status = SessionStatus.KILLED), afterCall = defaultSessionCallback) + smDao.updateSessionMain(sessionMeta.copy(status = status), afterCall = defaultSessionCallback) var cost = System.currentTimeMillis()-now logInfo(s"killing job update session main over $sessionId ,cost $cost") }finally { diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala index 4f13cf869..2c335e057 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterManager.scala @@ -146,7 +146,7 @@ object ClusterManagerService extends Logging { var interval = current - session.createTime.getTime logDebug(s"watch deepspeed new session: ${session.id} ${interval} ${maxInterval}") if (interval > maxInterval) { - JobServiceHandler.killJob(session.id, isTimeout = true) + JobServiceHandler.killJob(session.id, isTimeout = true,SessionStatus.ERROR) } } @@ -195,7 +195,7 @@ object ClusterManagerService extends Logging { if (sessionProcessors.exists(_.status == ProcessorStatus.ERROR)) { logInfo(s"session watcher kill session ${session}") try { - killJob(session.id, isTimeout = false) + killJob(session.id, isTimeout = false,SessionStatus.ERROR) } catch { case e: ErSessionException => logError(s"failed to kill session ${session.id}", e) diff --git a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterResourceManager.scala b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterResourceManager.scala index 8ef3dbab4..78eb94c7e 100644 --- a/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterResourceManager.scala +++ b/jvm/core/main/scala/com/webank/eggroll/core/resourcemanager/ClusterResourceManager.scala @@ -57,7 +57,7 @@ object ClusterResourceManager extends Logging{ var sessionLockMap = new ConcurrentHashMap[String,ReentrantLock]() - var killJobMap = new ConcurrentHashMap[String,Long]() +// var killJobMap = new ConcurrentHashMap[String,Long]() val applicationQueue = new FifoBroker[ResourceApplication] var resourceEventQueue = new FifoBroker[ResourceEvent] lazy val serverNodeCrudOperator = new ServerNodeCrudOperator() @@ -80,10 +80,11 @@ object ClusterResourceManager extends Logging{ var serverNodes :Array[ErServerNode]= null try { lockSession(resourceApplication.sessionId) - if(killJobMap.contains(resourceApplication.sessionId)){ + var sessionInDb = smDao.getSessionMain(resourceApplication.sessionId) + if(sessionInDb.status.equals(SessionStatus.KILLED)){ logError(s"session ${resourceApplication.sessionId} is already canceled , drop it") - applicationQueue.broker.remove() + resourceApplication.countDown() break() } @@ -91,6 +92,7 @@ object ClusterResourceManager extends Logging{ ) { //过期资源申请 logError(s"expired resource request : ${resourceApplication} !!!") + smDao.updateSessionStatus(sessionId = resourceApplication.sessionId,SessionStatus.ALLOCATE_RESOURCE_FAILED) applicationQueue.broker.remove() break() } @@ -103,7 +105,7 @@ object ClusterResourceManager extends Logging{ Thread.sleep(NodeManagerConfKeys.CONFKEY_NODE_MANAGER_HEARTBEAT_INTERVAL.get().toLong) }while((serverNodes==null||serverNodes.length==0)&&tryCount<2) var enough = checkResourceEnough(serverNodes, resourceApplication) - logInfo(s"check resource is enough ? ${enough}") + logInfo(s"check session id ${resourceApplication.sessionId} resource is enough ? ${enough}") if (!enough) { resourceApplication.resourceExhaustedStrategy match { case ResourceExhaustedStrategy.IGNORE => ; @@ -179,10 +181,6 @@ object ClusterResourceManager extends Logging{ } },"SYSTEM-RESOURCE-COUNT-THREAD") - - - - var lockCleanThread = new Thread(()=> { while (true) { logInfo("lock clean thread , prepare to run") @@ -192,12 +190,13 @@ object ClusterResourceManager extends Logging{ var es:ErSessionMeta = smDao.getSessionMain(k) if(es.updateTime!=null){ var updateTime = es.updateTime.getTime - if(now -updateTime>EGGROLL_RESOURCE_LOCK_EXPIRE_INTERVAL.get().toInt&& (es.status== SessionStatus.KILLED|| - es.status==SessionStatus.ERROR|| - es.status== SessionStatus.CLOSED|| - es.status== SessionStatus.FINISHED)){ + if(now -updateTime>EGGROLL_RESOURCE_LOCK_EXPIRE_INTERVAL.get().toInt&& (es.status.equals( SessionStatus.KILLED)|| + es.status.equals(SessionStatus.ERROR)|| + es.status.equals( SessionStatus.CLOSED)|| + es.status.equals( SessionStatus.FINISHED))|| + es.status.equals(SessionStatus.ALLOCATE_RESOURCE_FAILED) + ){ sessionLockMap.remove(es.id) - killJobMap.remove(es.id) } } @@ -514,8 +513,11 @@ object ClusterResourceManager extends Logging{ } def submitResourceRequest(resourceRequest: ResourceApplication):Unit={ - - + smDao.registerWithResource(ErSessionMeta( + id = resourceRequest.sessionId, + name = resourceRequest.sessionName, + status = SessionStatus.WAITING_RESOURCE) + ) applicationQueue.broker.put(resourceRequest) } From eeb07c820ec7a9452403c0466e57f3cdb4cfb851 Mon Sep 17 00:00:00 2001 From: chengtcc <864261919@qq.com> Date: Fri, 3 Nov 2023 14:14:11 +0800 Subject: [PATCH 15/18] update version Signed-off-by: chengtcc <864261919@qq.com> --- python/client/setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/client/setup.py b/python/client/setup.py index 08a1677de..ca05b9301 100644 --- a/python/client/setup.py +++ b/python/client/setup.py @@ -25,10 +25,10 @@ "click", "requests<2.26.0", "grpcio==1.46.3", - "numba==0.56.4", - "numpy==1.19.5", - "protobuf=3.19.6", - "ruamel.yaml==0.16" + "numba==0.53.0", + "numpy==1.23.1", + "protobuf==3.19.6", + "ruamel.yaml==0.16.10" ] entry_points = {"console_scripts": ["eggroll = client.cli.eggroll:eggroll_cli"]} From 3e153428fd6a0704018be6397478dd30a7f59520 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 7 Nov 2023 14:21:34 +0800 Subject: [PATCH 16/18] update whitelist Signed-off-by: mgqa34 --- conf/whitelist.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/conf/whitelist.json b/conf/whitelist.json index 64aed27ca..b22ed589e 100644 --- a/conf/whitelist.json +++ b/conf/whitelist.json @@ -27,7 +27,8 @@ "FeatureImportance" ], "federatedml.ensemble.basic_algorithms.decision_tree.tree_core.g_h_optim": [ - "SplitInfoPackage" + "SplitInfoPackage", + "SplitInfoPackage2" ], "federatedml.ensemble.basic_algorithms.decision_tree.tree_core.node": [ "Node" From fc9772c849564385401d6f057d04610ec9fb3504 Mon Sep 17 00:00:00 2001 From: mgqa34 Date: Tue, 7 Nov 2023 14:23:59 +0800 Subject: [PATCH 17/18] fix whitelist format Signed-off-by: mgqa34 --- conf/whitelist.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conf/whitelist.json b/conf/whitelist.json index b22ed589e..cf42058ad 100644 --- a/conf/whitelist.json +++ b/conf/whitelist.json @@ -28,7 +28,7 @@ ], "federatedml.ensemble.basic_algorithms.decision_tree.tree_core.g_h_optim": [ "SplitInfoPackage", - "SplitInfoPackage2" + "SplitInfoPackage2" ], "federatedml.ensemble.basic_algorithms.decision_tree.tree_core.node": [ "Node" From 4aeefdf73261b11342e8595e5f7e729d76752d3a Mon Sep 17 00:00:00 2001 From: chengtcc <864261919@qq.com> Date: Wed, 8 Nov 2023 20:13:00 +0800 Subject: [PATCH 18/18] add sleep Signed-off-by: chengtcc <864261919@qq.com> --- python/client/cli/commands/task.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/client/cli/commands/task.py b/python/client/cli/commands/task.py index 7a1ae1cfe..61d2c036f 100644 --- a/python/client/cli/commands/task.py +++ b/python/client/cli/commands/task.py @@ -17,6 +17,7 @@ import tempfile import datetime import click +import time from ..utils.cli_utils import load_yaml, prettify, unzip from client.sdk import EggrollClient @@ -67,6 +68,7 @@ def submit(ctx, **kwargs): while True: response = client.task.query_status(session_id=session_id) print(f'task session_id:{session_id} status:{response["status"]}') + time.sleep(1) if response["status"] != "NEW": break log_type = kwargs.get("log_type") if not kwargs.get("log_type") else "stdout"