Merge pull request #2 from DLR-KI/enhancement/inference
Refactor inference and add a preprocessing download endpoint
BeFranke authored Aug 13, 2024
2 parents d6bce32 + 9b9a062 commit ff97493
Showing 9 changed files with 514 additions and 72 deletions.
9 changes: 9 additions & 0 deletions fl_server_api/openapi.py
@@ -67,6 +67,15 @@ def create_error_response(
"""Generic OpenAPI 403 response."""


error_response_404 = create_error_response(
"Not found",
"Not found",
"The server cannot find the requested resource.",
"Provide valid request data."
)
"""Generic OpenAPI 404 response."""


def custom_preprocessing_hook(endpoints: List[Tuple[str, str, str, Callable]]):
"""
Hide the "/api/dummy/" endpoint from the OpenAPI schema.
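The new `error_response_404` mirrors the existing 403 helper and backs the preprocessing download endpoint added below. A minimal sketch of how such a response object is typically attached with drf-spectacular — the decorator placement and the 200 description are assumptions, not code from this PR; only the action name comes from the urls.py hunk further down:

```python
# Sketch only: attaching the new 404 response via drf-spectacular.
# The action name is taken from urls.py below; everything else is assumed.
from drf_spectacular.utils import OpenApiResponse, extend_schema

from fl_server_api.openapi import error_response_404


@extend_schema(
    responses={
        200: OpenApiResponse(description="TorchScript preprocessing model as raw bytes"),
        404: error_response_404,
    }
)
def get_model_proprecessing(self, request, id: str):
    ...
```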
4 changes: 4 additions & 0 deletions fl_server_api/serializers/model.py
@@ -93,6 +93,8 @@ def to_representation(self, instance):
del data["weights"]
if self.context.get("with-stats", False):
data["stats"] = self.analyze_torch_model(instance)
if isinstance(instance, GlobalModel):
data["has_preprocessing"] = bool(instance.preprocessing)
return data

def analyze_torch_model(self, instance: Model):
@@ -175,6 +177,7 @@ class ModelSerializerNoWeights(ModelSerializer):
class Meta:
model = Model
exclude = ["polymorphic_ctype", "weights"]
include = ["has_preprocessing"]


class ModelSerializerNoWeightsWithStats(ModelSerializerNoWeights):
@@ -186,6 +189,7 @@ class Meta:
model = Model
exclude = ["polymorphic_ctype", "weights"]
-include = ["stats"]
+include = ["has_preprocessing", "stats"]


#######################################################################################################################
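With this change, model metadata now reports whether a `GlobalModel` has a preprocessing model attached. A sketch of the resulting payload — all fields and values besides the `has_preprocessing` key are illustrative:

```python
# Illustrative metadata payload; only "has_preprocessing" (and "stats" for the
# with-stats serializer) is guaranteed by this diff, everything else is a
# placeholder.
metadata = {
    "id": "00000000-0000-0000-0000-000000000000",
    "name": "example-model",
    "description": "An example global model.",
    "input_shape": [None, 3],
    "has_preprocessing": True,  # new: bool(instance.preprocessing) for GlobalModel
    "stats": {},                # present only when requested with stats
}
```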
104 changes: 104 additions & 0 deletions fl_server_api/tests/test_inference.py
@@ -3,11 +3,15 @@
#
# SPDX-License-Identifier: Apache-2.0

import base64
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import TestCase
import json
import io
import pickle
import torch
import torch.nn
from torchvision.transforms.functional import to_pil_image
from uuid import uuid4

from fl_server_core.tests import BASE_URL, Dummy
@@ -168,3 +172,103 @@ def _inference_result(self, torch_model: torch.nn.Module):
self.assertIsNotNone(inference)
inference_tensor = torch.as_tensor(inference)
self.assertTrue(torch.all(torch.tensor([2, 0, 0]) == inference_tensor))

def test_inference_input_shape_positive(self):
inp = from_torch_tensor(torch.zeros(3, 3))
model = Dummy.create_model(input_shape=[None, 3])
training = Dummy.create_training(actor=self.user, model=model)
input_file = SimpleUploadedFile(
"input.pt",
inp,
content_type="application/octet-stream"
)
response = self.client.post(
f"{BASE_URL}/inference/",
{"model_id": str(training.model.id), "model_input": input_file}
)
self.assertEqual(response.status_code, 200)

def test_inference_input_shape_negative(self):
inp = from_torch_tensor(torch.zeros(3, 3))
model = Dummy.create_model(input_shape=[None, 5])
training = Dummy.create_training(actor=self.user, model=model)
input_file = SimpleUploadedFile(
"input.pt",
inp,
content_type="application/octet-stream"
)
with self.assertLogs("root", level="WARNING") as cm:
response = self.client.post(
f"{BASE_URL}/inference/",
{"model_id": str(training.model.id), "model_input": input_file}
)
self.assertEqual(cm.output, [
"WARNING:django.request:Bad Request: /api/inference/",
])
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()[0], "Input shape does not match model input shape.")

def test_inference_input_pil_image(self):
img = to_pil_image(torch.zeros(1, 5, 5))
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="jpeg")
img_byte_arr = img_byte_arr.getvalue()

torch.manual_seed(42)
torch_model = torch.jit.script(torch.nn.Sequential(
torch.nn.Conv2d(1, 2, 3),
torch.nn.Flatten(),
torch.nn.Linear(3*3, 2)
))
model = Dummy.create_model(input_shape=[None, 5, 5], weights=from_torch_module(torch_model))
training = Dummy.create_training(actor=self.user, model=model)
input_file = SimpleUploadedFile(
"input.pt",
img_byte_arr,
content_type="application/octet-stream"
)
response = self.client.post(
f"{BASE_URL}/inference/",
{"model_id": str(training.model.id), "model_input": input_file}
)
self.assertEqual(response.status_code, 200)

results = pickle.loads(response.content)
self.assertEqual({}, results["uncertainty"])
inference = results["inference"]
self.assertIsNotNone(inference)
inference_tensor = torch.as_tensor(inference)
self.assertTrue(torch.all(torch.tensor([0, 0]) == inference_tensor))

def test_inference_input_pil_image_base64(self):
img = to_pil_image(torch.zeros(1, 5, 5))
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="jpeg")
img_byte_arr = img_byte_arr.getvalue()
inp = base64.b64encode(img_byte_arr)

torch.manual_seed(42)
torch_model = torch.jit.script(torch.nn.Sequential(
torch.nn.Conv2d(1, 2, 3),
torch.nn.Flatten(),
torch.nn.Linear(3*3, 2)
))
model = Dummy.create_model(input_shape=[None, 5, 5], weights=from_torch_module(torch_model))
training = Dummy.create_training(actor=self.user, model=model)
input_file = SimpleUploadedFile(
"input.pt",
inp,
content_type="application/octet-stream"
)
response = self.client.post(
f"{BASE_URL}/inference/",
{"model_id": str(training.model.id), "model_input": input_file}
)
self.assertEqual(response.status_code, 200)

results = pickle.loads(response.content)
self.assertEqual({}, results["uncertainty"])
inference = results["inference"]
self.assertIsNotNone(inference)
inference_tensor = torch.as_tensor(inference)
self.assertTrue(torch.all(torch.tensor([0, 0]) == inference_tensor))
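The tests above exercise three input encodings for the inference endpoint: raw tensor bytes, JPEG bytes, and a base64-encoded JPEG. A hedged client-side sketch of the first case — the multipart field names and the pickled response format come from the tests, while the base URL, token, and the use of `torch.save` for serialization (the tests go through a `from_torch_tensor` helper whose exact format isn't shown in this diff) are assumptions:

```python
# Client sketch for POST /api/inference/. Field names ("model_id",
# "model_input") and the pickled response come from the tests; BASE_URL,
# TOKEN, MODEL_ID and the torch.save serialization are assumptions.
import io
import pickle

import requests
import torch

BASE_URL = "http://localhost:8000/api"  # assumption
TOKEN = "..."                           # assumption
MODEL_ID = "..."                        # UUID of the target model

buf = io.BytesIO()
torch.save(torch.zeros(3, 3), buf)  # input matching input_shape [None, 3]

response = requests.post(
    f"{BASE_URL}/inference/",
    headers={"Authorization": f"Token {TOKEN}"},
    data={"model_id": MODEL_ID},
    files={"model_input": ("input.pt", buf.getvalue(), "application/octet-stream")},
)
response.raise_for_status()
results = pickle.loads(response.content)  # dict with "inference" and "uncertainty"
```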
47 changes: 47 additions & 0 deletions fl_server_api/tests/test_model.py
@@ -149,6 +149,7 @@ def test_get_model_metadata(self):
self.assertEqual(str(model.name), response_json["name"])
self.assertEqual(str(model.description), response_json["description"])
self.assertEqual(model.input_shape, response_json["input_shape"])
self.assertFalse(response_json["has_preprocessing"])
# check stats
stats = response_json["stats"]
self.assertIsNotNone(stats)
@@ -232,6 +233,28 @@ def test_get_model_metadata(self):
self.assertIsNotNone(layer4["output_bytes"])
self.assertIsNotNone(layer4["macs"])

def test_get_model_metadata_with_preprocessing(self):
model_bytes = from_torch_module(torch.nn.Sequential(
torch.nn.Linear(3, 64),
torch.nn.ELU(),
torch.nn.Linear(64, 1),
))
torch_model_preprocessing = from_torch_module(transforms.Compose([
transforms.ToImage(),
transforms.ToDtype(torch.float32, scale=True),
transforms.Normalize(mean=(0.,), std=(1.,)),
]))
model = Dummy.create_model(weights=model_bytes, preprocessing=torch_model_preprocessing, input_shape=[None, 3])
response = self.client.get(f"{BASE_URL}/models/{model.id}/metadata/")
self.assertEqual(200, response.status_code)
self.assertEqual("application/json", response["content-type"])
response_json = response.json()
self.assertEqual(str(model.id), response_json["id"])
self.assertEqual(str(model.name), response_json["name"])
self.assertEqual(str(model.description), response_json["description"])
self.assertEqual(model.input_shape, response_json["input_shape"])
self.assertTrue(response_json["has_preprocessing"])

def test_get_model_metadata_torchscript_model(self):
torchscript_model_bytes = from_torch_module(torch.jit.script(torch.nn.Sequential(
torch.nn.Linear(3, 64),
@@ -552,6 +575,30 @@ def test_upload_model_preprocessing_v2_Compose_good(self):
self.assertIsNotNone(model.preprocessing)
self.assertTrue(isinstance(model.get_preprocessing_torch_model(), torch.nn.Module))

def test_download_model_preprocessing(self):
torch_model_preprocessing = from_torch_module(torch.jit.script(torch.nn.Sequential(
transforms.Normalize(mean=(0.,), std=(1.,)),
)))
model = Dummy.create_model(owner=self.user, preprocessing=torch_model_preprocessing)
response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/")
self.assertEqual(200, response.status_code)
self.assertEqual("application/octet-stream", response["content-type"])
torch_model = torch.jit.load(io.BytesIO(response.content))
self.assertIsNotNone(torch_model)
self.assertTrue(isinstance(torch_model, torch.nn.Module))

def test_download_model_preprocessing_with_undefined_preprocessing(self):
model = Dummy.create_model(owner=self.user, preprocessing=None)
with self.assertLogs("django.request", level="WARNING") as cm:
response = self.client.get(f"{BASE_URL}/models/{model.id}/preprocessing/")
self.assertEqual(cm.output, [
f"WARNING:django.request:Not Found: /api/models/{model.id}/preprocessing/",
])
self.assertEqual(404, response.status_code)
response_json = response.json()
self.assertIsNotNone(response_json)
self.assertEqual(f"Model '{model.id}' has no preprocessing model defined.", response_json["detail"])

@patch("fl_server_ai.trainer.tasks.process_trainer_task.apply_async")
def test_upload_update(self, apply_async: MagicMock):
model = Dummy.create_model(owner=self.user, round=0)
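The new GET endpoint returns the stored preprocessing model as raw TorchScript bytes, and answers 404 with a descriptive `detail` message when none is defined. A hedged sketch of consuming it, mirroring the `torch.jit.load(io.BytesIO(...))` pattern from the test above — the base URL and token are assumptions:

```python
# Sketch: download and apply a model's preprocessing pipeline. The path and
# the 404 detail come from the tests; BASE_URL, TOKEN and the input shape
# are assumptions.
import io

import requests
import torch

BASE_URL = "http://localhost:8000/api"  # assumption
TOKEN = "..."                           # assumption
MODEL_ID = "..."                        # UUID of the target model

response = requests.get(
    f"{BASE_URL}/models/{MODEL_ID}/preprocessing/",
    headers={"Authorization": f"Token {TOKEN}"},
)
if response.status_code == 404:
    # "Model '<id>' has no preprocessing model defined."
    print(response.json()["detail"])
else:
    preprocessing = torch.jit.load(io.BytesIO(response.content))
    batch = preprocessing(torch.zeros(1, 3))  # input shape illustrative
```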
2 changes: 1 addition & 1 deletion fl_server_api/urls.py
@@ -38,7 +38,7 @@
{"get": "get_model_metrics", "post": "create_model_metrics"}
), name="model-metrics"),
path("models/<str:id>/preprocessing/", view=Model.as_view(
{"post": "upload_model_preprocessing"}
{"get": "get_model_proprecessing", "post": "upload_model_preprocessing"}
), name="model-preprocessing"),
path("models/<str:id>/swag/", view=Model.as_view({"post": "create_swag_stats"}), name="model-swag"),
# trainings
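The route now dispatches GET to the new download action (note the `get_model_proprecessing` spelling, kept as it appears in the code) and POST to the existing upload. A hedged sketch of the upload side, reusing the scripted `Normalize` pipeline from `test_download_model_preprocessing` — the multipart field name `model_preprocessing` is hypothetical, since the upload test body isn't shown in this diff:

```python
# Sketch: upload a TorchScript preprocessing model to the shared route.
# The scripted pipeline mirrors test_download_model_preprocessing; the
# "model_preprocessing" field name is hypothetical.
import io

import requests
import torch
from torchvision.transforms import v2 as transforms

BASE_URL = "http://localhost:8000/api"  # assumption
TOKEN = "..."                           # assumption
MODEL_ID = "..."                        # UUID of the target model

scripted = torch.jit.script(torch.nn.Sequential(
    transforms.Normalize(mean=(0.,), std=(1.,)),
))
buf = io.BytesIO()
torch.jit.save(scripted, buf)

requests.post(
    f"{BASE_URL}/models/{MODEL_ID}/preprocessing/",
    headers={"Authorization": f"Token {TOKEN}"},
    files={"model_preprocessing": ("preprocessing.pt", buf.getvalue(),
                                   "application/octet-stream")},
)
```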
(Diffs for the remaining 4 changed files are not shown.)
