Skip to content

Commit

Permalink
Refactor output processing in OnnxInferenceModel prediction methods (#465)
Browse files Browse the repository at this point in the history

* Move createInputTensor to companion object

* Remove duplication between prediction methods

* Cleanup code

* Introduce functions for processing OrtSession.Result and inline predictRawWithShapes function

* Use custom OrtSession.Result processing functions in tests and examples

* Do not extract a map from the output in high-level models, but use OrtSession.Result itself
  • Loading branch information
juliabeliaeva authored Oct 8, 2022
1 parent 5add417 commit 4e2e52e
Show file tree
Hide file tree
Showing 16 changed files with 407 additions and 197 deletions.
6 changes: 3 additions & 3 deletions examples/src/main/kotlin/examples/onnx/faces/Fan2D106.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray
import org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment.Fan2D106FaceAlignmentModel
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
Expand Down Expand Up @@ -53,11 +54,10 @@ fun main() {
val inputImage = ImageConverter.toBufferedImage(inputFile)
val inputData = preprocessor.apply(inputImage).first

val yhat = it.predictRaw(inputData)
println(yhat.values.toTypedArray().contentDeepToString())
val floats = it.predictRaw(inputData) { output -> output.getFloatArray("fc1") }
println(floats.contentToString())

val landMarks = mutableListOf<Landmark>()
val floats = (yhat["fc1"] as Array<*>)[0] as FloatArray
for (j in floats.indices step 2) {
landMarks.add(Landmark((1 + floats[j]) / 2, (1 + floats[j + 1]) / 2))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package examples.onnx.objectdetection.ssd
import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
Expand Down Expand Up @@ -45,10 +46,10 @@ fun ssd() {
val inputData = preprocessing.load(getFileFromResource("datasets/detection/image$i.jpg")).first

val start = System.currentTimeMillis()
val yhat = it.predictRaw(inputData)
val yhat = it.predictRaw(inputData) { output -> output.getFloatArray(0)}
val end = System.currentTimeMillis()
println("Prediction took ${end - start} ms")
println(yhat.values.toTypedArray().contentDeepToString())
println(yhat.contentToString())
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray
import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose
import org.jetbrains.kotlinx.dl.api.inference.posedetection.MultiPoseDetectionResult
import org.jetbrains.kotlinx.dl.api.inference.posedetection.PoseLandmark
Expand Down Expand Up @@ -51,10 +52,11 @@ fun multiPoseDetectionMoveNet() {
.call(modelType.preprocessor)

val inputData = preprocessor.apply(inputImage).first
val yhat = it.predictRaw(inputData)
println(yhat.values.toTypedArray().contentDeepToString())
val rawPoseLandmarks = it.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}
println(rawPoseLandmarks.contentDeepToString())

val rawPoseLandmarks = (yhat["output_0"] as Array<Array<FloatArray>>)[0]
val poses = rawPoseLandmarks.mapNotNull { floats ->
val probability = floats[55]
if (probability < 0.05) return@mapNotNull null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package examples.onnx.posedetection.singlepose
import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray
import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose
import org.jetbrains.kotlinx.dl.api.inference.posedetection.PoseLandmark
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
Expand Down Expand Up @@ -49,10 +50,10 @@ fun poseDetectionMoveNet() {

val inputData = preprocessing.apply(image).first

val yhat = it.predictRaw(inputData)
println(yhat.values.toTypedArray().contentDeepToString())

val rawPoseLandMarks = (yhat["output_0"] as Array<Array<Array<FloatArray>>>)[0][0]
val rawPoseLandMarks = it.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}
println(rawPoseLandMarks.contentDeepToString())

// Dictionary that maps from joint names to keypoint indices.
val keypoints = mapOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package examples.onnx.faces
import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
Expand Down Expand Up @@ -56,8 +57,8 @@ class FacesTestSuite {
val imageFile = getFileFromResource("datasets/faces/image$i.jpg")
val inputData = fileDataLoader.load(imageFile).first

val yhat = it.predictRaw(inputData)
assertEquals(212, (yhat.values.toTypedArray()[0] as Array<FloatArray>)[0].size)
val yhat = it.predictRaw(inputData) { output -> output.getFloatArray(0) }
assertEquals(212, yhat.size)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package examples.onnx.posedetection
import examples.transferlearning.getFileFromResource
import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub
import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
Expand Down Expand Up @@ -84,9 +85,9 @@ class PoseDetectionTestSuite {

val inputData = fileDataLoader.load(imageFile).first

val yhat = it.predictRaw(inputData)

val rawPoseLandMarks = (yhat["output_0"] as Array<Array<Array<FloatArray>>>)[0][0]
val rawPoseLandMarks = it.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}

assertEquals(17, rawPoseLandMarks.size)
}
Expand All @@ -113,9 +114,9 @@ class PoseDetectionTestSuite {

val inputData = preprocessing.load(imageFile).first

val yhat = it.predictRaw(inputData)

val rawPoseLandMarks = (yhat["output_0"] as Array<Array<Array<FloatArray>>>)[0][0]
val rawPoseLandMarks = it.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}

assertEquals(17, rawPoseLandMarks.size)
}
Expand All @@ -142,10 +143,10 @@ class PoseDetectionTestSuite {


val inputData = dataLoader.load(imageFile).first
val yhat = inferenceModel.predictRaw(inputData)
println(yhat.values.toTypedArray().contentDeepToString())

val rawPosesLandMarks = (yhat["output_0"] as Array<Array<FloatArray>>)[0]
val rawPosesLandMarks = inferenceModel.predictRaw(inputData) { result ->
result.get2DFloatArray("output_0")
}
println(rawPosesLandMarks.contentDeepToString())

assertEquals(6, rawPosesLandMarks.size)
rawPosesLandMarks.forEach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.jetbrains.kotlinx.dl.api.inference.onnx

import ai.onnxruntime.OrtSession
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
Expand All @@ -29,15 +30,14 @@ public interface OnnxHighLevelModel<I, R> : ExecutionProviderCompatible {
/**
* Converts raw model output to the result.
*/
public fun convert(output: Map<String, Any>): R
public fun convert(output: OrtSession.Result): R

/**
* Makes prediction on the given [input].
*/
public fun predict(input: I): R {
val preprocessedInput = preprocessing.apply(input)
val output = internalModel.predictRaw(preprocessedInput.first)
return convert(output)
return internalModel.predictRaw(preprocessedInput.first) { convert(it) }
}

override fun initializeWith(vararg executionProviders: ExecutionProvider) {
Expand Down
Loading

0 comments on commit 4e2e52e

Please sign in to comment.