Skip to content

Commit

Permalink
updated cli.py to enable HA. (NVIDIA#263)
Browse files Browse the repository at this point in the history
  • Loading branch information
yhwen authored Mar 8, 2022
1 parent 5efd19d commit df48117
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 9 deletions.
79 changes: 71 additions & 8 deletions nvflare/fuel/hci/client/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import signal
import time
import traceback
import threading
from datetime import datetime
from enum import Enum
from functools import partial
Expand All @@ -28,6 +29,8 @@
from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandRegister, CommandSpec
from nvflare.fuel.hci.security import hash_password, verify_password
from nvflare.fuel.hci.table import Table
from nvflare.ha.overseer_agent import HttpOverseerAgent
from nvflare.apis.overseer_spec import SP

from .api import AdminAPI
from .api_status import APIStatus
Expand Down Expand Up @@ -89,6 +92,7 @@ def __init__(
self.require_login = require_login
self.credential_type = credential_type
self.user_name = None
self.password = None
self.pwd = None

self.debug = debug
Expand Down Expand Up @@ -123,6 +127,47 @@ def __init__(

signal.signal(signal.SIGUSR1, partial(self.session_signal_handler))

self.ssid = None
self.overseer_agent = self._create_overseer_agent()

if self.credential_type == CredentialType.CERT:
if self.overseer_agent:
self.overseer_agent.set_secure_context(ca_path=ca_cert,
cert_path=client_cert,
prv_key_path=client_key)

self.overseer_agent.start(self.overseer_callback)

def _create_overseer_agent(self):
overseer_agent = HttpOverseerAgent(
overseer_end_point="http://127.0.0.1:5000/api/v1",
project="example_project",
role="admin",
name="localhost",
heartbeat_interval=6,
)

return overseer_agent

def overseer_callback(self, overseer_agent):
sp = overseer_agent.get_primary_sp()
self.set_primary_sp(sp)

def set_primary_sp(self, sp: SP):
if sp and sp.primary is True:
if self.api.host != sp.name or self.api.port != int(sp.admin_port) or self.ssid != sp.service_session_id:
self.api.host = sp.name
self.api.port = int(sp.admin_port)
self.ssid = sp.service_session_id
print(f"Got primary SP. Host: {self.api.host} Admin_port: {self.api.port} SSID: {self.ssid}")

thread = threading.Thread(target=self._login_sp)
thread.start()

def _login_sp(self):
self.do_bye("logout")
self.login()

def session_ended(self, message):
self.write_error(message)
os.kill(os.getpid(), signal.SIGUSR1)
Expand Down Expand Up @@ -389,35 +434,53 @@ def cmdloop(self, intro=None):
pass

def run(self):

try:
while self.api.token is None:
time.sleep(1.0)

# self.api.start_session_monitor(self.session_ended)
self.cmdloop(intro='Type ? to list commands; type "? cmdName" to show usage of a command.')
finally:
self.overseer_agent.end()

def login(self):
if self.require_login:
user_name = input("User Name: ")
if self.user_name:
user_name = self.user_name
else:
user_name = input("User Name: ")

if self.credential_type == CredentialType.PASSWORD:
while True:
pwd = getpass.getpass("Password: ")
if self.password:
pwd = self.password
else:
pwd = getpass.getpass("Password: ")
# print(f"host: {self.api.host} port: {self.api.port}")
self.api.login_with_password(username=user_name, password=pwd)
self.stdout.write(f"login_result: {self.api.login_result} token: {self.api.token}\n{self.prompt}")
if self.api.login_result == "OK":
self.user_name = user_name
self.password = pwd
self.pwd = hash_password(pwd)
break
elif self.api.login_result == "REJECT":
print("Incorrect password - please try again")
else:
print("Communication Error - please try later")
return
return False
elif self.credential_type == CredentialType.CERT:
self.api.login(username=user_name)
if self.api.login_result == "OK":
self.user_name = user_name
elif self.api.login_result == "REJECT":
print("Incorrect user name or certificate")
return
return False
else:
print("Communication Error - please try later")
return

self.api.start_session_monitor(self.session_ended)
self.cmdloop(intro='Type ? to list commands; type "? cmdName" to show usage of a command.')
return False
return True

def print_resp(self, resp: dict):
"""Prints the server response
Expand Down
2 changes: 1 addition & 1 deletion nvflare/fuel/hci/tools/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def main():
require_login=args.with_login,
credential_type=CredentialType.PASSWORD if args.cred_type == "password" else CredentialType.CERT,
debug=args.with_debug,
cli_history_size=args.cli_history_size,
# cli_history_size=args.cli_history_size,
)

client.run()
Expand Down

0 comments on commit df48117

Please sign in to comment.