feat(stream_events): stream send()'s to client too
gadicc committed Apr 14, 2023
1 parent 46ea977 commit 08daf4f
Showing 4 changed files with 51 additions and 19 deletions.
api/app.py (14 changes: 8 additions & 6 deletions)

@@ -129,7 +129,7 @@ def truncateInputs(inputs: dict):

 # Inference is ran for every server call
 # Reference your preloaded global model variable here.
-def inference(all_inputs: dict) -> dict:
+async def inference(all_inputs: dict, response) -> dict:
     global model
     global pipelines
     global last_model_id
@@ -151,6 +151,8 @@ def inference(all_inputs: dict) -> dict:
         send_opts.update({"SEND_URL": call_inputs.get("SEND_URL")})
     if call_inputs.get("SIGN_KEY", None):
         send_opts.update({"SIGN_KEY": call_inputs.get("SIGN_KEY")})
+    if response:
+        send_opts.update({"response": response})
 
     if model_inputs == None or call_inputs == None:
         return {
@@ -356,7 +358,7 @@ def inference(all_inputs: dict) -> dict:
             )
         )
 
-    send("inference", "start", {"startRequestId": startRequestId}, send_opts)
+    await send("inference", "start", {"startRequestId": startRequestId}, send_opts)
 
     # Run patchmatch for inpainting
     if call_inputs.get("FILL_MODE", None) == "patchmatch":
@@ -417,7 +419,7 @@ def inference(all_inputs: dict) -> dict:
         send_opts=send_opts,
     )
     torch.set_grad_enabled(False)
-    send("inference", "done", {"startRequestId": startRequestId}, send_opts)
+    await send("inference", "done", {"startRequestId": startRequestId}, send_opts)
     result.update({"$timings": getTimings()})
     return result

@@ -435,8 +437,8 @@ def inference(all_inputs: dict) -> dict:
     callback = None
     if model_inputs.get("callback_steps", None):
 
-        def callback(step: int, timestep: int, latents: torch.FloatTensor):
-            send(
+        async def callback(step: int, timestep: int, latents: torch.FloatTensor):
+            await send(
                 "inference",
                 "progress",
                 {"startRequestId": startRequestId, "step": step},
@@ -473,7 +475,7 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
         image.save(buffered, format="PNG")
         images_base64.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
 
-    send("inference", "done", {"startRequestId": startRequestId}, send_opts)
+    await send("inference", "done", {"startRequestId": startRequestId}, send_opts)
 
     # Return the results as a dictionary
     if len(images_base64) > 1:
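
For context, the callback made async above is diffusers' per-step progress hook: the pipeline invokes it every callback_steps denoising steps with the step index, timestep, and current latents, which is what drives the "progress" events. A minimal sketch of that mechanism as documented around the time of this commit (the model id and prompt are placeholders, and the documented hook is a plain synchronous function):

# Sketch of the diffusers per-step callback (API circa early 2023).
# Model id and prompt are placeholders, not taken from the commit.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

def callback(step: int, timestep: int, latents: torch.FloatTensor):
    # Invoked every `callback_steps` steps; a natural place to emit progress.
    print({"step": step, "timestep": int(timestep)})

image = pipe(
    "an astronaut riding a horse",
    callback=callback,
    callback_steps=1,  # call back on every denoising step
).images[0]
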
api/send.py (7 changes: 6 additions & 1 deletion)

@@ -70,7 +70,7 @@ def getTimings():
     return timings
 
 
-def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
+async def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
     now = get_now()
     send_url = opts.get("SEND_URL", SEND_URL)
     sign_key = opts.get("SIGN_KEY", SIGN_KEY)
@@ -102,6 +102,11 @@ def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
     if send_url:
         futureSession.post(send_url, json=data)
 
+    response = opts.get("response")
+    if response:
+        print("streaming above")
+        await response.send(json.dumps(data) + "\n")
+
     # try:
     #     requests.post(send_url, json=data)  # , timeout=0.0000000001)
     # except requests.exceptions.ReadTimeout:
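
Put together, send() now has two delivery paths: the existing fire-and-forget webhook POST to SEND_URL, and, when a Sanic response handle is passed in opts, an NDJSON line streamed straight back to the caller. A self-contained sketch of that dual-path shape (the full contents of data are not shown in the diff, so the event fields here beyond type/status/payload are an assumption):

# Simplified, illustrative sketch of the dual-path send().
import json
import time

async def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
    # Assumed event shape; the real data dict carries more fields.
    data = {"type": type, "status": status, "payload": payload, "t": time.time()}

    send_url = opts.get("SEND_URL")
    if send_url:
        # Path 1: fire-and-forget webhook (futureSession in the real code).
        pass  # futureSession.post(send_url, json=data)

    response = opts.get("response")
    if response:
        # Path 2: one newline-terminated JSON document per event, so the
        # client can parse the stream line-by-line as it arrives.
        await response.send(json.dumps(data) + "\n")
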
api/server.py (21 changes: 16 additions & 5 deletions)

@@ -9,6 +9,7 @@
 import app as user_src
 import traceback
 import os
+import json
 
 # We do the model load-to-GPU step on server startup
 # so the model object is available globally for reuse
@@ -34,14 +35,21 @@ def healthcheck(request):
 
 # Inference POST handler at '/' is called for every http call from Banana
 @server.route("/", methods=["POST"])
-def inference(request):
+async def inference(request):
     try:
-        model_inputs = response.json.loads(request.json)
+        all_inputs = response.json.loads(request.json)
     except:
-        model_inputs = request.json
+        all_inputs = request.json
 
+    call_inputs = all_inputs.get("callInputs", None)
+    stream_events = call_inputs and call_inputs.get("streamEvents", 0) != 0
+
+    streaming_response = None
+    if stream_events:
+        streaming_response = await request.respond(content_type="application/x-ndjson")
+
     try:
-        output = user_src.inference(model_inputs)
+        output = await user_src.inference(all_inputs, streaming_response)
     except Exception as err:
         output = {
             "$error": {
@@ -52,7 +60,10 @@ def inference(request):
             }
         }
 
-    return response.json(output)
+    if stream_events:
+        await streaming_response.send(json.dumps(output) + "\n")
+    else:
+        return response.json(output)
 
 
 if __name__ == "__main__":
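
The server side leans on Sanic's streaming API: await request.respond(...) opens the response immediately and returns a handle whose send() writes chunks, which is why the handler returns nothing in streaming mode; the response is already on the wire. A minimal standalone sketch of that pattern (app name and route are illustrative, not the project's):

# Standalone sketch of Sanic NDJSON streaming; names are illustrative.
import json
from sanic import Sanic

app = Sanic("ndjson_demo")

@app.route("/events", methods=["POST"])
async def events(request):
    # respond() opens the response right away; nothing is returned at the end.
    resp = await request.respond(content_type="application/x-ndjson")
    for step in range(3):
        await resp.send(json.dumps({"status": "progress", "step": step}) + "\n")
    await resp.send(json.dumps({"status": "done"}) + "\n")

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8000)
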
test.py (28 changes: 21 additions & 7 deletions)

@@ -114,7 +114,9 @@ def runTest(name, args, extraCallInputs, extraModelInputs):
"modelInputs": inputs,
"startOnly": False,
}

response = requests.post(f"{BANANA_API_URL}/start/v4/", json=payload)

result = response.json()
callID = result.get("callID")

@@ -185,13 +187,25 @@ def runTest(name, args, extraCallInputs, extraModelInputs):
 
     else:
         test_url = args.get("test_url", None) or TEST_URL
-        response = requests.post(test_url, json=inputs)
-        try:
-            result = response.json()
-        except requests.exceptions.JSONDecodeError as error:
-            print(error)
-            print(response.text)
-            sys.exit(1)
+        call_inputs = inputs["callInputs"]
+        stream_events = call_inputs and call_inputs.get("streamEvents", 0) != 0
+        print({"stream_events": stream_events})
+        if stream_events:
+            result = None
+            response = requests.post(test_url, json=inputs, stream=True)
+            for line in response.iter_lines():
+                if line:
+                    result = json.loads(line)
+                    if not result.get("$timings", None):
+                        print(result)
+        else:
+            response = requests.post(test_url, json=inputs)
+            try:
+                result = response.json()
+            except requests.exceptions.JSONDecodeError as error:
+                print(error)
+                print(response.text)
+                sys.exit(1)
 
     finish = time.time() - start
     timings = result.get("$timings")
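
The client logic above works because each event is newline-terminated: requests' iter_lines() yields one complete JSON document per line as it arrives, and the final line (the only one carrying $timings) is left in result when the loop ends. An async consumer would look much the same; a hedged sketch using httpx (the URL and payload are placeholders, not from the repo):

# Async equivalent of the streaming consumer above, using httpx.
# URL and payload are placeholders, not from the repo.
import asyncio
import json
import httpx

async def consume(url: str, inputs: dict):
    result = None
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=inputs) as response:
            async for line in response.aiter_lines():
                if line:
                    result = json.loads(line)
                    if not result.get("$timings"):
                        print(result)  # intermediate event
    return result  # final payload, the line carrying $timings

# asyncio.run(consume("http://localhost:8000/", {"callInputs": {"streamEvents": 1}}))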
