Skip to content

Commit

Permalink
Change case handling from Other and add Failure case
Browse files Browse the repository at this point in the history
  • Loading branch information
Robi9 committed Sep 19, 2023
1 parent 2baf773 commit c1bd00d
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 17 deletions.
33 changes: 17 additions & 16 deletions flows/routers/smart.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (r *SmartRouter) Validate(flow flows.Flow, exits []flows.Exit) error {
}

// and each case test is valid
if c.Type != "has_any_word" {
if c.Type != "has_any_word" && c.Type != "has_category" {
return errors.Errorf("case must be of type 'has_any_words', not %s", c.Type)
}
}
Expand All @@ -119,19 +119,11 @@ func (r *SmartRouter) Route(run flows.FlowRun, step flows.Step, logEvent flows.E

// classify text between categories
categoryName, categoryUUID, err := r.classifyText(run, step, operandAsStr, logEvent)
if err != nil {
return "", "", err
}

// none of our cases matched, so try to use the default
if categoryUUID == "" && r.defaultCategoryUUID != "" {
// evaluate our operand as a string
value, xerr := types.ToXText(env, operand)
if xerr != nil {
run.LogError(step, xerr)
}

categoryName = value.Native()
if err != nil && r.defaultCategoryUUID != "" {
categoryName = "Failure"
categoryUUID = r.defaultCategoryUUID
} else if categoryUUID == "" && r.defaultCategoryUUID != "" {
categoryName = "All Responses"
categoryUUID = r.defaultCategoryUUID
}

Expand All @@ -150,6 +142,10 @@ func SetAPIURL(url string) {
}

func (r *SmartRouter) classifyText(run flows.FlowRun, step flows.Step, operand string, logEvent flows.EventCallback) (string, flows.CategoryUUID, error) {
if len(r.categories) == 1 && len(r.cases) == 0 {
return "", "", nil
}

url := apiUrl + "/v2/repository/nlp/zeroshot/zeroshot-fast-predict"
status := flows.CallStatusSuccess
body := struct {
Expand Down Expand Up @@ -253,7 +249,6 @@ func (r *SmartRouter) classifyText(run flows.FlowRun, step flows.Step, operand s
err = jsonx.Unmarshal(trace.ResponseBody, response)
if err != nil {
run.LogError(step, err)
return "", "", err
}

call := &flows.ZeroshotCall{
Expand All @@ -266,8 +261,14 @@ func (r *SmartRouter) classifyText(run flows.FlowRun, step flows.Step, operand s
var categoryUUID flows.CategoryUUID
categoryUUID = ""

// case with 'other' option
if response.Output.Other {
return "", categoryUUID, nil
for _, c := range r.cases {
if c.Type == "has_category" {
categoryUUID = c.CategoryUUID
}
}
return "Other", categoryUUID, nil
}

for _, category := range r.categories {
Expand Down
3 changes: 3 additions & 0 deletions flows/routers/testdata/_assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
},
{
"uuid": "b787ffe3-c21a-46ad-9475-954614b52477"
},
{
"uuid": "22654595-af24-4be4-928e-ed1c268daeb3"
}
]
}
Expand Down
64 changes: 63 additions & 1 deletion flows/routers/testdata/smart.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
{
"uuid": "78ae8f05-f92e-43b2-a886-406eaea1b8e0",
"name": "Supercalifragilisticexpialidocious",
"exit_uuid": "b787ffe3-c21a-46ad-9475-954614b52477"
"exit_uuid": "22654595-af24-4be4-928e-ed1c268daeb3"
},
{
"uuid": "78ae8f05-f92e-43b2-a886-406eaea1b8e0",
Expand Down Expand Up @@ -139,5 +139,67 @@
"waiting_exits": [],
"parent_refs": []
}
},
{
"description": "Result created with matching test result",
"router": {
"type": "smart",
"result_name": "Product",
"categories": [
{
"uuid": "78ae8f05-f92e-43b2-a886-406eaea1b8e0",
"name": "All Responses",
"exit_uuid": "b787ffe3-c21a-46ad-9475-954614b52477"
}
],
"operand": "@(\"How much does it cost?\")",
"cases": [],
"default_category_uuid": "78ae8f05-f92e-43b2-a886-406eaea1b8e0"
},
"results": {
"product": {
"name": "Product",
"value": "All Responses",
"category": "All Responses",
"node_uuid": "64373978-e8f6-4973-b6ff-a2993f3376fc",
"input": "How much does it cost?",
"created_on": "2018-10-18T14:20:30.000123456Z"
}
},
"events": [
{
"category": "All Responses",
"created_on": "2018-10-18T14:20:30.000123456Z",
"input": "How much does it cost?",
"name": "Product",
"step_uuid": "59d74b86-3e2f-4a93-aece-b05d2fdcde0c",
"type": "run_result_changed",
"value": "All Responses"
}
],
"templates": [
"@(\"How much does it cost?\")"
],
"localizables": [
"All Responses"
],
"inspection": {
"dependencies": [],
"issues": [],
"results": [
{
"key": "product",
"name": "Product",
"categories": [
"All Responses"
],
"node_uuids": [
"64373978-e8f6-4973-b6ff-a2993f3376fc"
]
}
],
"waiting_exits": [],
"parent_refs": []
}
}
]

0 comments on commit c1bd00d

Please sign in to comment.