Skip to content

Commit

Permalink
feat: commandline script for joligan server calls
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau authored and pnsuau committed Aug 29, 2022
1 parent d18cee9 commit 48ae23b
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 6 deletions.
101 changes: 101 additions & 0 deletions client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
#
# JoliGAN Python client
#
# Licence:
#
# Copyright 2020-2022 Jolibrain SASU

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Here are calls examples that you can use to make API calls to a JoliGAN server. Please note that you have to run a server first.
#Launch a training
python client.py --host jg_server_host --port jg_server_port [joligan commandline options eg --dataroot /path/to/data --model_type cut]
NB: the name given in joligan commandline options will also be the name of the training process.
# List trainings in progress
python client.py --method training_status --host jg_server_host --port jg_server_port
# Stop a training
python client.py --method training_status --host jg_server_host --port jg_server_port --name training_name
"""

import requests
from options.client_options import ClientOptions
import sys


def train(host: str, port: int, name: str, client_options: dict):
train_options = client_options.copy()
del train_options["method"]
del train_options["host"]
del train_options["port"]

json_opt = {}
json_opt["sync"] = False
json_opt["train_options"] = train_options

url = "http://%s:%d" % (host, port) + "/train/%s" % name

x = requests.post(url=url, json=json_opt)

print("Training %s started." % x.json()["name"])


def delete(host: str, port: int, name: str):
url = "http://%s:%d" % (host, port) + "/train/%s" % name
x = requests.delete(url=url)

print("Training %s has been stopped." % x.json()["name"])


def get_status(host: str, port: int):
url = "http://%s:%d" % (host, port) + "/train"

x = requests.get(url=url)

print("There are %i trainings in progress." % (len(x.json()["processes"])))

for process in x.json()["processes"]:
print("Name: %s, status: %s" % (process["name"], process["status"]))


def main_client(args):
if not "launch_training" in args and not "--dataroot" in args:
args += ["--dataroot", "unused"]

client_options = ClientOptions().parse_to_json(args)

host = client_options["host"]
port = client_options["port"]
method = client_options["method"]
name = client_options["name"]

if method == "launch_training":
train(host, port, name, client_options)

elif method == "stop_training":
delete(host, port, name)

elif method == "training_status":
get_status(host, port)
else:
raise


if __name__ == "__main__":
args = sys.argv[1:] # removing the script name
main_client(args)
26 changes: 26 additions & 0 deletions docs/client.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# JoliGAN client python

Here are calls examples that you can use to make API calls to a JoliGAN server. Please note that you have to run a server first.

#### Launch a training

```
python client.py --host jg_server_host --port jg_server_port [joligan commandline options eg --dataroot /path/to/data --model_type cut]
```

NB: the name given in joligan commandline options will also be the name of the training process.

#### List trainings in progress

```
python client.py --method training_status --host jg_server_host --port jg_server_port
```

#### Stop a training

```
python client.py --method training_status --host jg_server_host --port jg_server_port --name training_name
```
12 changes: 8 additions & 4 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def initialize(self, parser):
self.initialized = True
return parser

def gather_options(self):
def gather_options(self, args=None):
"""Initialize our parser with basic options(only once).
Add additional model-specific and dataset-specific options.
These options are defined in the <modify_commandline_options> function
Expand All @@ -548,13 +548,13 @@ def gather_options(self):
parser = self.initialize(parser)

# get the basic options
opt, _ = parser.parse_known_args()
opt, _ = parser.parse_known_args(args)

# modify model-related parser options
model_name = opt.model_type
model_option_setter = models.get_option_setter(model_name)
parser = model_option_setter(parser, self.isTrain)
opt, _ = parser.parse_known_args() # parse again with new defaults
opt, _ = parser.parse_known_args(args) # parse again with new defaults

# modify dataset-related parser options
dataset_name = opt.data_dataset_mode
Expand All @@ -563,7 +563,7 @@ def gather_options(self):

# save and return the parser
self.parser = parser
return parser.parse_args()
return parser.parse_args(args)

def print_options(self, opt):
"""Print and save options
Expand Down Expand Up @@ -701,6 +701,10 @@ def parse(self):
opt = self._after_parse(self.opt)
return opt

def parse_to_json(self, args=None):
self.opt = self.gather_options(args)
return self.to_json()

def _json_parse_known_args(self, parser, opt, json_args):
"""
json_args: flattened json of train options
Expand Down
31 changes: 31 additions & 0 deletions options/client_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from .train_options import TrainOptions
import argparse


class ClientOptions(TrainOptions):
def initialize(self, parser):

parser = TrainOptions.initialize(self, parser)

parser.add_argument(
"--method",
type=str,
default="launch_training",
choices=["launch_training", "stop_training", "training_status"],
)

parser.add_argument(
"--host",
type=str,
required=True,
help="joligan server host",
)

parser.add_argument(
"--port",
type=int,
required=True,
help="joligan server post",
)

return parser
8 changes: 8 additions & 0 deletions scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,11 @@ if [ $OUT != 0 ]; then
else
exit 0
fi

#### Client server test
SERVER_HOST="10.10.77.108"
SERVER_PORT=8000

python3 -m pytest ${current_dir}/../tests/client_test_server.py --host $SERVER_HOST --port $SERVER_PORT


176 changes: 176 additions & 0 deletions tests/client_test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs
from signal import signal, SIGINT
from sys import exit
from json import loads, dumps
from argparse import ArgumentParser
from threading import Thread
import sys
import time

sys.path.append(sys.path[0] + "/..")
import client


class HttpHandler(BaseHTTPRequestHandler):
protocol_version = "HTTP/1.1"
error_content_type = "text/plain"
error_message_format = "Error %(code)d: %(message)s"

def do_GET(self):
path, args = self.parse_url()

if path == "/train" and args == {}:
self.write_response(
200,
"application/json",
dumps({"processes": {}}),
)
else:
self.send_error(400, "Invalid path or args")

def do_DELETE(self):
path, args = self.parse_url()

if path == "/train/test_client" and args == {}:
self.write_response(
200,
"application/json",
dumps({"message": "ok", "name": "test_client"}),
)
else:
self.send_error(400, "Invalid path or args")

def do_POST(self):
path, _ = self.parse_url()
body = self.read_body()

if (
path == "/train/test_client"
and self.parse_json(body)["train_options"] is not None
):

self.write_response(
200,
"application/json",
dumps({"message": "ok", "name": "test_client", "status": "running"}),
)
else:
self.send_error(400, "Invalid json received")

def parse_url(self):
url_components = urlparse(self.path)
return url_components.path, parse_qs(url_components.query)

def parse_json(self, content):
try:
return loads(content)
except Exception:
return None

def read_body(self):
try:
content_length = int(self.headers["Content-Length"])
return self.rfile.read(content_length).decode("utf-8")
except Exception:
return None

def write_response(self, status_code, content_type, content):
response = content.encode("utf-8")

self.send_response(status_code)
self.send_header("Content-Type", content_type)
self.send_header("Content-Length", str(len(response)))
self.end_headers()
self.wfile.write(response)

def version_string(self):
return "Tiny Http Server"

def log_error(self, format, *args):
pass


def start_server(host, port):
server_address = (host, port)
httpd = ThreadingHTTPServer(server_address, HttpHandler)
print(f"Server started on {host}:{port}")
httpd.serve_forever()


def shutdown_handler(signum, frame):
print("Shutting down server")
exit(0)


def main():
signal(SIGINT, shutdown_handler)
parser = ArgumentParser(description="Start a tiny HTTP/1.1 server")
parser.add_argument(
"--host",
type=str,
action="store",
default="localhost",
help="Server host (default: localhost)",
)
parser.add_argument(
"--port",
type=int,
action="store",
default=8000,
help="Server port (default: 8000)",
)
args = parser.parse_args()

Thread(target=start_server, daemon=True, args=[args.host, args.port]).start()

time.sleep(1)

client.main_client(
args=[
"--host",
args.host,
"--port",
str(args.port),
"--method",
"training_status",
"--name",
"test_client",
]
)

time.sleep(1)

client.main_client(
args=[
"--host",
args.host,
"--port",
str(args.port),
"--method",
"stop_training",
"--name",
"test_client",
]
)

time.sleep(1)

client.main_client(
args=[
"--host",
args.host,
"--port",
str(args.port),
"--method",
"launch_training",
"--name",
"test_client",
"--dataroot",
"fake_dataroot",
]
)


if __name__ == "__main__":
main()
Loading

0 comments on commit 48ae23b

Please sign in to comment.