Skip to content

Commit

Permalink
Comparison done
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-mo committed Jan 15, 2020
1 parent 5af791f commit 87cf856
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 69 deletions.
21 changes: 16 additions & 5 deletions js/src/Components/Compare.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { ModelObject } from "../Utils/ProtoUtil";
import { FC, useState } from "react";
import React from "react";
import { Checkbox, Empty } from 'antd';
import { Checkbox, Empty, Tag } from 'antd';
import { Chart } from "react-google-charts";


Expand All @@ -22,7 +22,11 @@ function runToCheckBox(run: InferenceRun, id: number, pushChangedID: (id: number
popChangedID(id)
}
}}>
{run.model.name}, {run.input_type}, {run.output_type}, {JSON.stringify(run.queryMetadata)},
{id}:
<Tag>{run.model.metadata.framework}</Tag>
<Tag>{run.model.name}</Tag>
<Tag>{run.input_type}</Tag>
<Tag>{run.output_type}</Tag>
</Checkbox> <br></br>
</div>
}
Expand All @@ -37,8 +41,8 @@ function createPlot(runs: InferenceRun[]): JSX.Element {
}

const dataItems = runs.map((value, _, __) => [
value.model.name,
Number.parseFloat(value.queryMetadata["model_runtime_s"]),
value.model.metadata.framework + " " + value.model.name,
Number.parseFloat(value.queryMetadata["model_runtime_s"]) * 1000,
null,
null])

Expand All @@ -50,7 +54,7 @@ function createPlot(runs: InferenceRun[]): JSX.Element {
data={[
[
'Element',
'Latency (s)',
'Latency (ms)',
{ role: 'style' },
{
sourceColumn: 0,
Expand All @@ -71,6 +75,12 @@ function createPlot(runs: InferenceRun[]): JSX.Element {
height: 400,
bar: { groupWidth: '95%' },
legend: { position: 'none' },
hAxis: {
title: "Latency(ms)",
},
yAxis: {
title: "Query"
}
}}
/>
}
Expand All @@ -91,6 +101,7 @@ export const CompareRuntime: FC<CompareProps> = props => {
}

return <div>
<h1>Compare (earliest on top)</h1>
{allRuns.map((value, id, __) => runToCheckBox(value, id, pushSelectedIndex, popSelectedIndex))}

{createPlot(selectedIndex.map((value, _, __) => allRuns[value]))}
Expand Down
2 changes: 1 addition & 1 deletion js/src/Views/Home.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export const Catalog: FC<CatalogProps> = props => {

let cards = models.map((model: ModelObject, index, arr) => {
return (
<Col span={8}>
<Col span={8} key={index}>
<Card
title={model.name}
style={{ margin: "2px" }}
Expand Down
91 changes: 41 additions & 50 deletions js/src/Views/Model.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import _ from "lodash";
import { ModelzooServicePromiseClient } from "js/generated/modelzoo/protos/services_grpc_web_pb";
import { Empty, Payload, PayloadType } from "js/generated/modelzoo/protos/services_pb";
import { Image, Text } from "js/generated/modelzoo/protos/model_apis_pb";
import React, { Dispatch, FC, useMemo, useReducer } from "react";
import React, { Dispatch, FC, useMemo, useReducer, useEffect } from "react";
import { useParams } from "react-router-dom";
import { ImageInput, ImageOutput } from "../Components/Image";
import { TableOutput } from "../Components/Table";
Expand Down Expand Up @@ -42,7 +42,7 @@ interface ModelInferenceState {
outputElement: JSX.Element;

props: ModelProps | undefined;
seenResponseIDs: Set<number>;
seenResponseIDs: Set<string>;
}

const modelInitialState: ModelInferenceState = {
Expand Down Expand Up @@ -100,6 +100,26 @@ type ModelActionUnion =
| ModelAction
| SetPayloadAction;

function deriveDisplayElement(state: ModelInferenceState, displayPayload: Payload): JSX.Element {
switch (state.inputType) {
case "image":
return <ImageOutput
image_uri={displayPayload.getImage()!.getImageDataUrl()}
></ImageOutput>
case "text":
return <TextsOutput
texts={displayPayload.getText()!.getTextsList()}
></TextsOutput>
case "table":
return <TableOutput
tableProto={displayPayload.getTable()!}
></TableOutput>
default:
message.error("Unknown input type " + state.inputType);
return <div></div>;
}
}

function reducer(
state: ModelInferenceState,
action: ModelActionUnion
Expand Down Expand Up @@ -183,11 +203,7 @@ function reducer(
};

case ModelActionType.SetInput:
state.dispatch!({
type: ModelActionType.SetDisplayResult,
payload: (action as SetInputAction).payload
});

console.log("about to call client inference")
state
.client!.inference((action as SetInputAction).payload, undefined)
.then(resp =>
Expand All @@ -200,49 +216,20 @@ function reducer(
message.error("Can't ping the inference API: " + err.message)
);


return {
...state,
outputElement: <Spin></Spin>,
displayElement: deriveDisplayElement(state, (action as SetInputAction).payload)
};
case ModelActionType.SetDisplayResult:
let displayPayload = (action as SetDisplayAction).payload;
switch (state.inputType) {
case "image":
return {
...state,
displayElement: (
<ImageOutput
image_uri={displayPayload.getImage()!.getImageDataUrl()}
></ImageOutput>
)
};
case "text":
return {
...state,
displayElement: (
<TextsOutput
texts={displayPayload.getText()!.getTextsList()}
></TextsOutput>
)
};
case "table":
return {
...state,
displayElement: (
<TableOutput
tableProto={displayPayload.getTable()!}
></TableOutput>
)
};
default:
message.error("Unknown input type " + state.inputType);
return state;
}

case ModelActionType.SetOutputResult:
let payload = (action as SetOutputAction).payload;
if (state.seenResponseIDs.has(payload.getResponseId())) {
return state;
}
// console.log(state)
// console.log(payload.getResponseId().toString())
// if (state.seenResponseIDs.has(payload.getResponseId().toString())) {
// return state;
// }

const currentRunData: InferenceRun = {
model: state.modelObject!,
Expand All @@ -261,7 +248,7 @@ function reducer(
image_uri={payload.getImage()!.getImageDataUrl()}
></ImageOutput>
),
seenResponseIDs: state.seenResponseIDs.add(payload.getResponseId())
seenResponseIDs: state.seenResponseIDs.add(payload.getResponseId().toString())
};
case "text":
return {
Expand All @@ -271,15 +258,15 @@ function reducer(
texts={payload.getText()!.getTextsList()}
></TextsOutput>
),
seenResponseIDs: state.seenResponseIDs.add(payload.getResponseId())
seenResponseIDs: state.seenResponseIDs.add(payload.getResponseId().toString())
};
case "table":
return {
...state,
outputElement: (
<TableOutput tableProto={payload.getTable()!}></TableOutput>
),
seenResponseIDs: state.seenResponseIDs.add(payload.getResponseId())
seenResponseIDs: state.seenResponseIDs.add(payload.getResponseId().toString())
};
default:
message.error("Unknown output type " + state.outputType);
Expand Down Expand Up @@ -312,12 +299,16 @@ export const Model: FC<ModelProps> = props => {
// Parse props
let { name } = useParams();
let { client, token } = props;
modelInitialState.props = props;
const [state, dispatch] = useReducer(reducer, modelInitialState);

const [state, dispatch] = useReducer(reducer, modelInitialState, (state) => {

console.log("in initializer")
return { ...state, props: props }
});


// Initial Action: fetch model
useMemo(() => {
useEffect(() => {
dispatch({
type: ModelActionType.SetModelName,
modelName: name as string,
Expand Down
27 changes: 17 additions & 10 deletions modelzoo/protos/services_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions modelzoo/protos/services_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class ImageDownloadResponse(google___protobuf___message___Message):
class Payload(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...
type = ... # type: PayloadType
response_id = ... # type: builtin___int

@property
def image(self) -> modelzoo___protos___model_apis_pb2___Image: ...
Expand All @@ -217,17 +218,18 @@ class Payload(google___protobuf___message___Message):
image : typing___Optional[modelzoo___protos___model_apis_pb2___Image] = None,
text : typing___Optional[modelzoo___protos___model_apis_pb2___Text] = None,
table : typing___Optional[modelzoo___protos___model_apis_pb2___Table] = None,
response_id : typing___Optional[builtin___int] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> Payload: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
def HasField(self, field_name: typing_extensions___Literal[u"image",u"payload",u"table",u"text"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"image",u"payload",u"table",u"text",u"type"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"image",u"payload",u"response_id",u"table",u"text",u"type"]) -> None: ...
else:
def HasField(self, field_name: typing_extensions___Literal[u"image",b"image",u"payload",b"payload",u"table",b"table",u"text",b"text"]) -> builtin___bool: ...
def ClearField(self, field_name: typing_extensions___Literal[u"image",b"image",u"payload",b"payload",u"table",b"table",u"text",b"text",u"type",b"type"]) -> None: ...
def ClearField(self, field_name: typing_extensions___Literal[u"image",b"image",u"payload",b"payload",u"response_id",b"response_id",u"table",b"table",u"text",b"text",u"type",b"type"]) -> None: ...
def WhichOneof(self, oneof_group: typing_extensions___Literal[u"payload",b"payload"]) -> typing_extensions___Literal["image","text","table"]: ...

class MetricItems(google___protobuf___message___Message):
Expand Down
2 changes: 1 addition & 1 deletion modelzoo/sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def wrapped(inp, metadata):

started = time.time()
out = func(*args, **kwargs)
metadata["model_runtime_s"] = str((time.time() - started)*1000)
metadata["model_runtime_s"] = str((time.time() - started))

return self._out_transformer(out)
return wrapped
Expand Down

0 comments on commit 87cf856

Please sign in to comment.