Skip to content

Commit

Permalink
Support more types of object in JSON output
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Firshman <[email protected]>
  • Loading branch information
bfirsh committed Jul 30, 2021
1 parent f451eea commit f843830
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
console.Output(string(output))
return nil
} else if output.MimeType == "application/json" {
var obj map[string]interface{}
var obj interface{}
dec := json.NewDecoder(output.Buffer)
if err := dec.Decode(&obj); err != nil {
return err
Expand Down
46 changes: 46 additions & 0 deletions python/cog/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import json

# Based on keepsake.json

# We load numpy but not torch or tensorflow because numpy loads very fast and
# they're probably using it anyway
# fmt: off
try:
import numpy as np # type: ignore
has_numpy = True
except ImportError:
has_numpy = False
# fmt: on

# Tensorflow takes a solid 10 seconds to import on a modern Macbook Pro, so instead of importing,
# do this instead
def _is_tensorflow_tensor(obj):
# e.g. __module__='tensorflow.python.framework.ops', __name__='EagerTensor'
return (
obj.__class__.__module__.split(".")[0] == "tensorflow"
and "Tensor" in obj.__class__.__name__
)


def _is_torch_tensor(obj):
return (obj.__class__.__module__, obj.__class__.__name__) == ("torch", "Tensor")


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also understands numpy scalars/arrays and torch/tensorflow tensors.

    Each unsupported object is converted to a plain Python value
    (int, float, or nested list) before encoding; anything else falls
    through to the base class, which raises TypeError.
    """

    def default(self, o):
        # numpy conversions are only attempted when numpy imported successfully
        # at module load time (see the guarded import above the helpers).
        if has_numpy:
            if isinstance(o, np.integer):
                return int(o)
            if isinstance(o, np.floating):
                return float(o)
            if isinstance(o, np.ndarray):
                return o.tolist()
        # torch tensors must be detached from the autograd graph before
        # conversion; tensorflow tensors go via .numpy() first.
        if _is_torch_tensor(o):
            return o.detach().tolist()
        if _is_tensorflow_tensor(o):
            return o.numpy().tolist()
        return super().default(o)


def to_json(obj):
    """Serialize *obj* to a JSON string.

    Uses CustomJSONEncoder so numpy scalars/arrays and torch/tensorflow
    tensors are converted to plain JSON types during encoding.
    """
    encoder = CustomJSONEncoder()
    return encoder.encode(obj)
12 changes: 8 additions & 4 deletions python/cog/server/ai_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_type_name,
UNSPECIFIED,
)
from ..json import to_json
from ..model import Model, run_model, load_model


Expand Down Expand Up @@ -38,10 +39,13 @@ def handle_request():
except InputValidationError as e:
return jsonify({"error": str(e)})
results.append(run_model(self.model, instance, cleanup_functions))
return jsonify(
{
"predictions": results,
}
return Response(
to_json(
{
"predictions": results,
}
),
mimetype="application/json",
)
except Exception as e:
tb = traceback.format_exc()
Expand Down
3 changes: 2 additions & 1 deletion python/cog/server/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
get_type_name,
UNSPECIFIED,
)
from ..json import to_json
from ..model import Model, run_model, load_model


Expand Down Expand Up @@ -99,7 +100,7 @@ def create_response(self, result, setup_time, run_time):
elif isinstance(result, str):
resp = Response(result, mimetype="text/plain")
else:
resp = jsonify(result)
resp = Response(to_json(result), mimetype="application/json")
resp.headers["X-Setup-Time"] = setup_time
resp.headers["X-Run-Time"] = run_time
return resp
Expand Down
3 changes: 2 additions & 1 deletion python/cog/server/redis_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


from ..input import InputValidationError, validate_and_convert_inputs
from ..json import to_json
from ..model import Model, run_model, load_model


Expand Down Expand Up @@ -210,7 +211,7 @@ def push_result(self, response_queue, result):
}
else:
message = {
"value": json.dumps(result),
"value": to_json(result),
}

sys.stderr.write(f"Pushing successful result to {response_queue}\n")
Expand Down
18 changes: 17 additions & 1 deletion python/cog_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pytest
from flask.testing import FlaskClient
import numpy as np
from PIL import Image

import cog
Expand Down Expand Up @@ -220,7 +221,7 @@ def predict(self, num1, num2, num3):
client = make_client(Model())
resp = client.post("/predict", data={"num1": 3, "num2": -4, "num3": -4})
assert resp.status_code == 200
assert resp.data == b"-5.0\n"
assert resp.data == b"-5.0"
resp = client.post("/predict", data={"num1": 2, "num2": -4, "num3": -4})
assert resp.status_code == 400
resp = client.post("/predict", data={"num1": 3, "num2": -4.1, "num3": -4})
Expand Down Expand Up @@ -392,6 +393,21 @@ def predict(self):
assert resp.content_length == 195894


def test_json_output_numpy():
    """A dict containing a numpy scalar is returned as plain JSON by /predict."""

    class NumpyOutputModel(cog.Model):
        def setup(self):
            pass

        def predict(self):
            # np.float32 is not natively JSON-serializable; the server's
            # custom encoder must convert it to a plain float.
            return {"foo": np.float32(1.0)}

    test_client = make_client(NumpyOutputModel())
    response = test_client.post("/predict")
    assert response.status_code == 200
    assert response.content_type == "application/json"
    assert response.data == b'{"foo": 1.0}'


def test_multiple_arguments():
class Model(cog.Model):
def setup(self):
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
flask==2.0.1
numpy==1.21.1
pillow==8.2.0
pytest==6.2.4
PyYAML==5.4.1
Expand Down

0 comments on commit f843830

Please sign in to comment.