diff --git a/examples/src/main/kotlin/examples/onnx/faces/Fan2D106.kt b/examples/src/main/kotlin/examples/onnx/faces/Fan2D106.kt index 949291395..ec9fe4e8b 100644 --- a/examples/src/main/kotlin/examples/onnx/faces/Fan2D106.kt +++ b/examples/src/main/kotlin/examples/onnx/faces/Fan2D106.kt @@ -9,6 +9,7 @@ import examples.transferlearning.getFileFromResource import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray import org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment.Fan2D106FaceAlignmentModel import org.jetbrains.kotlinx.dl.dataset.image.ColorMode import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter @@ -53,11 +54,10 @@ fun main() { val inputImage = ImageConverter.toBufferedImage(inputFile) val inputData = preprocessor.apply(inputImage).first - val yhat = it.predictRaw(inputData) - println(yhat.values.toTypedArray().contentDeepToString()) + val floats = it.predictRaw(inputData) { output -> output.getFloatArray("fc1") } + println(floats.contentToString()) val landMarks = mutableListOf() - val floats = (yhat["fc1"] as Array<*>)[0] as FloatArray for (j in floats.indices step 2) { landMarks.add(Landmark((1 + floats[j]) / 2, (1 + floats[j + 1]) / 2)) } diff --git a/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSD.kt b/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSD.kt index 7ee927633..8bb9cad5f 100644 --- a/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSD.kt +++ b/examples/src/main/kotlin/examples/onnx/objectdetection/ssd/SSD.kt @@ -8,6 +8,7 @@ package examples.onnx.objectdetection.ssd import examples.transferlearning.getFileFromResource import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray import org.jetbrains.kotlinx.dl.dataset.image.ColorMode import org.jetbrains.kotlinx.dl.dataset.preprocessing.call import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline @@ -45,10 +46,10 @@ fun ssd() { val inputData = preprocessing.load(getFileFromResource("datasets/detection/image$i.jpg")).first val start = System.currentTimeMillis() - val yhat = it.predictRaw(inputData) + val yhat = it.predictRaw(inputData) { output -> output.getFloatArray(0)} val end = System.currentTimeMillis() println("Prediction took ${end - start} ms") - println(yhat.values.toTypedArray().contentDeepToString()) + println(yhat.contentToString()) } } } diff --git a/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNet.kt b/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNet.kt index a04ebf918..dae7a1256 100644 --- a/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNet.kt +++ b/examples/src/main/kotlin/examples/onnx/posedetection/multipose/multiPoseDetectionMoveNet.kt @@ -9,6 +9,7 @@ import examples.transferlearning.getFileFromResource import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray import 
org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose import org.jetbrains.kotlinx.dl.api.inference.posedetection.MultiPoseDetectionResult import org.jetbrains.kotlinx.dl.api.inference.posedetection.PoseLandmark @@ -51,10 +52,11 @@ fun multiPoseDetectionMoveNet() { .call(modelType.preprocessor) val inputData = preprocessor.apply(inputImage).first - val yhat = it.predictRaw(inputData) - println(yhat.values.toTypedArray().contentDeepToString()) + val rawPoseLandmarks = it.predictRaw(inputData) { result -> + result.get2DFloatArray("output_0") + } + println(rawPoseLandmarks.contentDeepToString()) - val rawPoseLandmarks = (yhat["output_0"] as Array>)[0] val poses = rawPoseLandmarks.mapNotNull { floats -> val probability = floats[55] if (probability < 0.05) return@mapNotNull null diff --git a/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/poseDetectionMoveNet.kt b/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/poseDetectionMoveNet.kt index b0cc89cb3..ed47fd9d0 100644 --- a/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/poseDetectionMoveNet.kt +++ b/examples/src/main/kotlin/examples/onnx/posedetection/singlepose/poseDetectionMoveNet.kt @@ -8,6 +8,7 @@ package examples.onnx.posedetection.singlepose import examples.transferlearning.getFileFromResource import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose import org.jetbrains.kotlinx.dl.api.inference.posedetection.PoseLandmark import org.jetbrains.kotlinx.dl.dataset.image.ColorMode @@ -49,10 +50,10 @@ fun poseDetectionMoveNet() { val inputData = preprocessing.apply(image).first - val yhat = it.predictRaw(inputData) - println(yhat.values.toTypedArray().contentDeepToString()) - - val rawPoseLandMarks = (yhat["output_0"] as Array>>)[0][0] + val rawPoseLandMarks = it.predictRaw(inputData) { result -> + result.get2DFloatArray("output_0") + } + println(rawPoseLandMarks.contentDeepToString()) // Dictionary that maps from joint names to keypoint indices. 
val keypoints = mapOf( diff --git a/examples/src/test/kotlin/examples/onnx/faces/FacesTestSuite.kt b/examples/src/test/kotlin/examples/onnx/faces/FacesTestSuite.kt index b45d11970..d00449140 100644 --- a/examples/src/test/kotlin/examples/onnx/faces/FacesTestSuite.kt +++ b/examples/src/test/kotlin/examples/onnx/faces/FacesTestSuite.kt @@ -8,6 +8,7 @@ package examples.onnx.faces import examples.transferlearning.getFileFromResource import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray import org.jetbrains.kotlinx.dl.dataset.image.ColorMode import org.jetbrains.kotlinx.dl.dataset.preprocessing.call import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline @@ -56,8 +57,8 @@ class FacesTestSuite { val imageFile = getFileFromResource("datasets/faces/image$i.jpg") val inputData = fileDataLoader.load(imageFile).first - val yhat = it.predictRaw(inputData) - assertEquals(212, (yhat.values.toTypedArray()[0] as Array)[0].size) + val yhat = it.predictRaw(inputData) { output -> output.getFloatArray(0) } + assertEquals(212, yhat.size) } } } diff --git a/examples/src/test/kotlin/examples/onnx/posedetection/PoseDetectionTestSuite.kt b/examples/src/test/kotlin/examples/onnx/posedetection/PoseDetectionTestSuite.kt index 154cf02ba..b0c608ae7 100644 --- a/examples/src/test/kotlin/examples/onnx/posedetection/PoseDetectionTestSuite.kt +++ b/examples/src/test/kotlin/examples/onnx/posedetection/PoseDetectionTestSuite.kt @@ -8,6 +8,7 @@ package examples.onnx.posedetection import examples.transferlearning.getFileFromResource import org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray import org.jetbrains.kotlinx.dl.dataset.image.ColorMode import org.jetbrains.kotlinx.dl.dataset.preprocessing.call import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline @@ -84,9 +85,9 @@ class PoseDetectionTestSuite { val inputData = fileDataLoader.load(imageFile).first - val yhat = it.predictRaw(inputData) - - val rawPoseLandMarks = (yhat["output_0"] as Array>>)[0][0] + val rawPoseLandMarks = it.predictRaw(inputData) { result -> + result.get2DFloatArray("output_0") + } assertEquals(17, rawPoseLandMarks.size) } @@ -113,9 +114,9 @@ class PoseDetectionTestSuite { val inputData = preprocessing.load(imageFile).first - val yhat = it.predictRaw(inputData) - - val rawPoseLandMarks = (yhat["output_0"] as Array>>)[0][0] + val rawPoseLandMarks = it.predictRaw(inputData) { result -> + result.get2DFloatArray("output_0") + } assertEquals(17, rawPoseLandMarks.size) } @@ -142,10 +143,10 @@ class PoseDetectionTestSuite { val inputData = dataLoader.load(imageFile).first - val yhat = inferenceModel.predictRaw(inputData) - println(yhat.values.toTypedArray().contentDeepToString()) - - val rawPosesLandMarks = (yhat["output_0"] as Array>)[0] + val rawPosesLandMarks = inferenceModel.predictRaw(inputData) { result -> + result.get2DFloatArray("output_0") + } + println(rawPosesLandMarks.contentDeepToString()) assertEquals(6, rawPosesLandMarks.size) rawPosesLandMarks.forEach { diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxHighLevelModel.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxHighLevelModel.kt index 77ef91e0f..a93ef3429 100644 --- 
a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxHighLevelModel.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxHighLevelModel.kt @@ -5,6 +5,7 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx +import ai.onnxruntime.OrtSession import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape @@ -29,15 +30,14 @@ public interface OnnxHighLevelModel : ExecutionProviderCompatible { /** * Converts raw model output to the result. */ - public fun convert(output: Map): R + public fun convert(output: OrtSession.Result): R /** * Makes prediction on the given [input]. */ public fun predict(input: I): R { val preprocessedInput = preprocessing.apply(input) - val output = internalModel.predictRaw(preprocessedInput.first) - return convert(output) + return internalModel.predictRaw(preprocessedInput.first) { convert(it) } } override fun initializeWith(vararg executionProviders: ExecutionProvider) { diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt index 96bd235a2..aea5d3d40 100644 --- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt @@ -7,16 +7,16 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx import ai.onnxruntime.* import ai.onnxruntime.OrtSession.SessionOptions -import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape import org.jetbrains.kotlinx.dl.api.extension.argmax import org.jetbrains.kotlinx.dl.api.inference.InferenceModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getValues +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.throwIfOutputNotSupported import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU +import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape import java.nio.* -import java.util.* - -private const val RESHAPE_MISSED_MESSAGE = "Model input shape is not defined. Call reshape() to set input shape." /** * Inference model built on ONNX format. 
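Aside, for context on the OnnxInferenceModel hunks below: the reworked predictRaw replaces the manual create-tensor / run / close sequence with a caller-provided extraction callback wrapped in `use` blocks, so the native resources are released even when extraction throws. The following is a minimal standalone sketch of that pattern, not the library code itself; the name runAndExtract is illustrative and, unlike the library's createTensor helper, it handles float inputs only.

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import java.nio.FloatBuffer

// Sketch of the callback-based run: both the input tensor and the OrtSession.Result
// are AutoCloseable, so `use` closes them after the extraction lambda returns or throws.
fun <R> runAndExtract(
    env: OrtEnvironment,
    session: OrtSession,
    data: FloatArray,
    shape: LongArray,
    extract: (OrtSession.Result) -> R
): R =
    OnnxTensor.createTensor(env, FloatBuffer.wrap(data), shape).use { tensor ->
        session.run(mapOf(session.inputNames.first() to tensor)).use { result ->
            extract(result)
        }
    }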
@@ -126,8 +126,7 @@ public open class OnnxInferenceModel private constructor(private val modelSource private fun initInputOutputInfo() { val inputTensorInfo = session.inputInfo.toList()[0].second.info as TensorInfo if (!::inputShape.isInitialized) { - val inputDims = - inputTensorInfo.shape.takeLast(3).toLongArray() + val inputDims = inputTensorInfo.shape.takeLast(3).toLongArray() inputShape = TensorShape(1, *inputDims).dims() } inputDataType = inputTensorInfo.type @@ -199,32 +198,16 @@ public open class OnnxInferenceModel private constructor(private val modelSource } override fun predictSoftly(inputData: FloatArray, predictionTensorName: String): FloatArray { - require(::inputShape.isInitialized) { RESHAPE_MISSED_MESSAGE } - - val outputTensorName = when { - predictionTensorName.isEmpty() -> session.outputNames.first() - else -> predictionTensorName + val outputTensorName = predictionTensorName.ifEmpty { session.outputNames.first() } + require(outputTensorName in session.outputInfo) { + "There is no output with name '$outputTensorName'." + + " The model only has following outputs - ${session.outputInfo.keys}" } - require(outputTensorName in session.outputInfo) { "There is no output with name '$outputTensorName'. The model only has following outputs - ${session.outputInfo.keys}" } - - throwIfOutputNotSupported(outputTensorName, "predictSoftly") - - val outputIdx = session.outputInfo.keys.indexOf(outputTensorName) + val outputInfo = session.outputInfo.getValue(outputTensorName).info + throwIfOutputNotSupported(outputInfo, outputTensorName, "predictSoftly", OnnxJavaType.FLOAT) - return predictSoftly(inputData, outputIdx) - } - - /** - * Currently, some methods only support float tensors as model output. - * This method checks if model output satisfies these requirements. - */ - // TODO: add support for all ONNX output types (see https://github.com/Kotlin/kotlindl/issues/367) - private fun throwIfOutputNotSupported(outputName: String, method: String) { - val outputInfo = session.outputInfo[outputName]!!.info - require(outputInfo !is MapInfo) { "Output $outputName is a Map, but currently method $method supports only float Tensor outputs. Please use predictRaw method instead." } - require(outputInfo !is SequenceInfo) { "Output '$outputName' is a Sequence, but currently method $method supports only float Tensor outputs. Please use predictRaw method instead." } - require(outputInfo is TensorInfo && outputInfo.type == OnnxJavaType.FLOAT) { "Currently method $method supports only float Tensor outputs, but output '$outputName' is not a float Tensor. Please use predictRaw method instead." 
} + return predictRaw(inputData) { output -> output.getFloatArray(outputTensorName) } } /** @@ -237,91 +220,28 @@ public open class OnnxInferenceModel private constructor(private val modelSource return predictSoftly(inputData, session.outputNames.first()) } - private fun predictSoftly(inputData: FloatArray, outputTensorIdx: Int): FloatArray { - val inputTensor = createInputTensor(inputData) - - val outputTensor = session.run(Collections.singletonMap(session.inputNames.toList()[0], inputTensor)) - - val outputInfo = session.outputInfo.toList()[outputTensorIdx].second.info as TensorInfo - - val outputProbs: FloatArray = when { - outputInfo.shape.size > 1 -> (outputTensor[outputTensorIdx].value as Array)[0] - else -> outputTensor[outputTensorIdx].value as FloatArray - } - - outputTensor.close() - inputTensor.close() - - return outputProbs - } - /** * Returns list of multidimensional arrays with data from model outputs. * * NOTE: This operation can be quite slow for high dimensional tensors, - * you should prefer [predictRawWithShapes] in this case. + * use [predictRaw] with custom output processing for better performance. */ public fun predictRaw(inputData: FloatArray): Map { - require(::inputShape.isInitialized) { RESHAPE_MISSED_MESSAGE } - - val inputTensor = createInputTensor(inputData) - - val outputTensor = session.run(Collections.singletonMap(session.inputNames.toList()[0], inputTensor)) - - val result = mutableMapOf() - - outputTensor.forEach { - result[it.key] = it.value.value - } - - outputTensor.close() - inputTensor.close() - - return result.toMap() + return predictRaw(inputData) { it.getValues() } } - // TODO: refactor predictRaw and predictRawWithShapes to extract the common functionality - /** - * Returns list of pairs from model outputs. + * Runs prediction on a given [inputData] and calls [extractResult] function to process output. + * @see OrtSessionResultConversions */ - // TODO: add tests for many available models - // TODO: return map - public fun predictRawWithShapes(inputData: FloatArray): List> { - require(::inputShape.isInitialized) { RESHAPE_MISSED_MESSAGE } - - session.outputInfo.keys.forEach { - throwIfOutputNotSupported(it, "predictRawWithShapes") - } + public fun predictRaw(inputData: FloatArray, extractResult: (OrtSession.Result) -> R): R { + require(::inputShape.isInitialized) { "Model input shape is not defined. Call reshape() to set input shape." } - val preparedData = FloatBuffer.wrap(inputData) - - val tensor = OnnxTensor.createTensor(env, preparedData, inputShape) - val output = session.run(Collections.singletonMap(session.inputNames.toList()[0], tensor)) - - val result = mutableListOf>() - - output.forEach { - val onnxTensorShape = (it.value.info as TensorInfo).shape - result.add(Pair((it.value as OnnxTensor).floatBuffer, onnxTensorShape)) + return env.createTensor(inputData, inputDataType, inputShape).use { inputTensor -> + session.run(mapOf(session.inputNames.first() to inputTensor)).use { output -> + extractResult(output) + } } - - output.close() - tensor.close() - - return result.toList() - } - - /** - * Predicts the class of [inputData]. - * - * @param [inputData] The single example with unknown label. - * @param [inputTensorName] The name of input tensor. - * @param [outputTensorName] The name of output tensor. - * @return Predicted class index. 
- */ - public fun predict(inputData: FloatArray, inputTensorName: String, outputTensorName: String): Int { - TODO("ONNX doesn't support extraction outputs from the intermediate levels of the model.") } override fun copy( @@ -355,54 +275,51 @@ public open class OnnxInferenceModel private constructor(private val modelSource return "OnnxModel(session=$session)" } - private fun createInputTensor(inputData: FloatArray): OnnxTensor { - val inputTensor = when (inputDataType) { - OnnxJavaType.FLOAT -> OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), inputShape) - OnnxJavaType.DOUBLE -> OnnxTensor.createTensor( - env, - DoubleBuffer.wrap(inputData.map { it.toDouble() }.toDoubleArray()), - inputShape - ) - - OnnxJavaType.INT8 -> OnnxTensor.createTensor( - env, - ByteBuffer.wrap(inputData.map { it.toInt().toByte() }.toByteArray()), - inputShape - ) - - OnnxJavaType.INT16 -> OnnxTensor.createTensor( - env, - ShortBuffer.wrap(inputData.map { it.toInt().toShort() }.toShortArray()), - inputShape - ) - - OnnxJavaType.INT32 -> OnnxTensor.createTensor( - env, - IntBuffer.wrap(inputData.map { it.toInt() }.toIntArray()), - inputShape - ) - - OnnxJavaType.INT64 -> OnnxTensor.createTensor( - env, - LongBuffer.wrap(inputData.map { it.toLong() }.toLongArray()), - inputShape - ) - - OnnxJavaType.STRING -> TODO() - OnnxJavaType.UINT8 -> OnnxTensor.createTensor( - env, - ByteBuffer.wrap(inputData.map { it.toInt().toUByte().toByte() }.toByteArray()), - inputShape, - OnnxJavaType.UINT8 - ) - - OnnxJavaType.UNKNOWN -> TODO() - else -> TODO() + public companion object { + private fun OrtEnvironment.createTensor(data: FloatArray, + dataType: OnnxJavaType, + shape: LongArray + ): OnnxTensor { + val inputTensor = when (dataType) { + OnnxJavaType.FLOAT -> OnnxTensor.createTensor(this, FloatBuffer.wrap(data), shape) + OnnxJavaType.DOUBLE -> OnnxTensor.createTensor( + this, + DoubleBuffer.wrap(data.map { it.toDouble() }.toDoubleArray()), + shape + ) + OnnxJavaType.INT8 -> OnnxTensor.createTensor( + this, + ByteBuffer.wrap(data.map { it.toInt().toByte() }.toByteArray()), + shape + ) + OnnxJavaType.INT16 -> OnnxTensor.createTensor( + this, + ShortBuffer.wrap(data.map { it.toInt().toShort() }.toShortArray()), + shape + ) + OnnxJavaType.INT32 -> OnnxTensor.createTensor( + this, + IntBuffer.wrap(data.map { it.toInt() }.toIntArray()), + shape + ) + OnnxJavaType.INT64 -> OnnxTensor.createTensor( + this, + LongBuffer.wrap(data.map { it.toLong() }.toLongArray()), + shape + ) + OnnxJavaType.STRING -> TODO() + OnnxJavaType.UINT8 -> OnnxTensor.createTensor( + this, + ByteBuffer.wrap(data.map { it.toInt().toUByte().toByte() }.toByteArray()), + shape, + OnnxJavaType.UINT8 + ) + OnnxJavaType.UNKNOWN -> TODO() + else -> TODO() + } + return inputTensor } - return inputTensor - } - public companion object { /** * Loads model from serialized ONNX file. */ diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OrtSessionResultConversions.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OrtSessionResultConversions.kt new file mode 100644 index 000000000..d3c07c00b --- /dev/null +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OrtSessionResultConversions.kt @@ -0,0 +1,269 @@ +/* + * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved. + * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file. 
+ */ + +package org.jetbrains.kotlinx.dl.api.inference.onnx + +import ai.onnxruntime.* + +/** + * Convenience functions for processing [OrtSession.Result]. + */ +public object OrtSessionResultConversions { + /** + * Returns the output at [index] as a [FloatArray] with its shape. + */ + public fun OrtSession.Result.getFloatArrayWithShape(index: Int): Pair { + return get(index).getFloatArrayWithShape() + } + + /** + * Returns the output at [index] as a [FloatArray]. + */ + public fun OrtSession.Result.getFloatArray(index: Int): FloatArray { + return getFloatArrayWithShape(index).first + } + + /** + * Returns the output by [name] as a [FloatArray] with its shape. + */ + public fun OrtSession.Result.getFloatArrayWithShape(name: String): Pair { + return get(name).get().getFloatArrayWithShape() + } + + /** + * Returns the output by [name] as a [FloatArray]. + */ + public fun OrtSession.Result.getFloatArray(name: String): FloatArray { + return getFloatArrayWithShape(name).first + } + + private fun OnnxValue.getFloatArrayWithShape(): Pair { + throwIfOutputNotSupported(info, toString(), "getFloatArray", OnnxJavaType.FLOAT) + val shape = (info as TensorInfo).shape + return (this as OnnxTensor).floatBuffer.array() to shape + } + + /** + * Returns the output at [index] as a [DoubleArray] with its shape. + */ + public fun OrtSession.Result.getDoubleArrayWithShape(index: Int): Pair { + return get(index).getDoubleArrayWithShape() + } + + /** + * Returns the output at [index] as a [DoubleArray]. + */ + public fun OrtSession.Result.getDoubleArray(index: Int): DoubleArray { + return getDoubleArrayWithShape(index).first + } + + /** + * Returns the output by [name] as a [DoubleArray] with its shape. + */ + public fun OrtSession.Result.getDoubleArrayWithShape(name: String): Pair { + return get(name).get().getDoubleArrayWithShape() + } + + /** + * Returns the output by [name] as a [DoubleArray]. + */ + public fun OrtSession.Result.getDoubleArray(name: String): DoubleArray { + return getDoubleArrayWithShape(name).first + } + + private fun OnnxValue.getDoubleArrayWithShape(): Pair { + throwIfOutputNotSupported(info, toString(), "getDoubleArray", OnnxJavaType.DOUBLE) + val shape = (info as TensorInfo).shape + return (this as OnnxTensor).doubleBuffer.array() to shape + } + + /** + * Returns the output at [index] as a [LongArray] with its shape. + */ + public fun OrtSession.Result.getLongArrayWithShape(index: Int): Pair { + return get(index).getLongArrayWithShape() + } + + /** + * Returns the output at [index] as a [LongArray]. + */ + public fun OrtSession.Result.getLongArray(index: Int): LongArray { + return getLongArrayWithShape(index).first + } + + /** + * Returns the output by [name] as a [LongArray] with its shape. + */ + public fun OrtSession.Result.getLongArrayWithShape(name: String): Pair { + return get(name).get().getLongArrayWithShape() + } + + /** + * Returns the output by [name] as a [FloatArray]. + */ + public fun OrtSession.Result.getLongArray(name: String): LongArray { + return getLongArrayWithShape(name).first + } + + private fun OnnxValue.getLongArrayWithShape(): Pair { + throwIfOutputNotSupported(info, toString(), "getLongArray", OnnxJavaType.INT64) + val shape = (info as TensorInfo).shape + return (this as OnnxTensor).longBuffer.array() to shape + } + + /** + * Returns the output at [index] as an [IntArray] with its shape. 
+ */ + public fun OrtSession.Result.getIntArrayWithShape(index: Int): Pair { + return get(index).getIntArrayWithShape() + } + + /** + * Returns the output at [index] as an [IntArray]. + */ + public fun OrtSession.Result.getIntArray(index: Int): IntArray { + return getIntArrayWithShape(index).first + } + + /** + * Returns the output by [name] as an [IntArray] with its shape. + */ + public fun OrtSession.Result.getIntArrayWithShape(name: String): Pair { + return get(name).get().getIntArrayWithShape() + } + + /** + * Returns the output by [name] as an [IntArray]. + */ + public fun OrtSession.Result.getIntArray(name: String): IntArray { + return getIntArrayWithShape(name).first + } + + private fun OnnxValue.getIntArrayWithShape(): Pair { + throwIfOutputNotSupported(info, toString(), "getIntArray", OnnxJavaType.INT32) + val shape = (info as TensorInfo).shape + return (this as OnnxTensor).intBuffer.array() to shape + } + + /** + * Returns the output at [index] as a [ShortArray] with its shape. + */ + public fun OrtSession.Result.getShortArrayWithShape(index: Int): Pair { + return get(index).getShortArrayWithShape() + } + + /** + * Returns the output at [index] as a [ShortArray]. + */ + public fun OrtSession.Result.getShortArray(index: Int): ShortArray { + return getShortArrayWithShape(index).first + } + + /** + * Returns the output by [name] as a [ShortArray] with its shape. + */ + public fun OrtSession.Result.getShortArrayWithShape(name: String): Pair { + return get(name).get().getShortArrayWithShape() + } + + /** + * Returns the output by [name] as a [ShortArray]. + */ + public fun OrtSession.Result.getShortArray(name: String): ShortArray { + return getShortArrayWithShape(name).first + } + + private fun OnnxValue.getShortArrayWithShape(): Pair { + throwIfOutputNotSupported(info, toString(), "getShortArray", OnnxJavaType.INT16) + val shape = (info as TensorInfo).shape + return (this as OnnxTensor).shortBuffer.array() to shape + } + + /** + * Returns the output at [index] as a [ByteArray] with its shape. + */ + public fun OrtSession.Result.getByteArrayWithShape(index: Int): Pair { + return get(index).getByteArrayWithShape() + } + + /** + * Returns the output at [index] as a [ByteArray]. + */ + public fun OrtSession.Result.getByteArray(index: Int): ByteArray { + return getByteArrayWithShape(index).first + } + + /** + * Returns the output by [name] as a [ByteArray] with its shape. + */ + public fun OrtSession.Result.getByteArrayWithShape(name: String): Pair { + return get(name).get().getByteArrayWithShape() + } + + /** + * Returns the output by [name] as a [ByteArray]. + */ + public fun OrtSession.Result.getByteArray(name: String): ByteArray { + return getByteArrayWithShape(name).first + } + + private fun OnnxValue.getByteArrayWithShape(): Pair { + throwIfOutputNotSupported(info, toString(), "getByteArray", OnnxJavaType.STRING) + val shape = (info as TensorInfo).shape + return (this as OnnxTensor).byteBuffer.array() to shape + } + + /** + * Returns the output by [name] as an Array. This operation could be slow for high dimensional tensors, + * in which case [getFloatArray] should be used. + */ + public fun OrtSession.Result.get2DFloatArray(name: String): Array { + return get(name).get().get2DFloatArray() + } + + /** + * Returns the output at [index] as an Array. This operation could be slow for high dimensional tensors, + * in which case [getFloatArray] should be used. 
+ */ + public fun OrtSession.Result.get2DFloatArray(index: Int): Array { + return get(index).get2DFloatArray() + } + + @Suppress("UNCHECKED_CAST") + private fun OnnxValue.get2DFloatArray(): Array { + throwIfOutputNotSupported(info, toString(), "get2DFloatArray", OnnxJavaType.FLOAT) + val shape = (info as TensorInfo).shape + val depth = shape.size - 2 + require(depth >= 0 && shape.slice(0 until depth).all { it == 1L }) { + "Output of shape $shape can't be converted to the Array." + } + var result = value as Array<*> + repeat(depth) { + result = result[0] as Array<*> + } + return result as Array + } + + /** + * Returns all values from this [OrtSession.Result]. This operation could be slow for high dimensional tensors, + * in which case functions that return one dimensional array such as [getFloatArray] or [getLongArray] should be used. + * @see OnnxValue.getValue + */ + public fun OrtSession.Result.getValues(): Map = associate { it.key to it.value.value } + + /** + * Checks if [valueInfo] corresponds to a Tensor of the specified [type]. + * If it does not satisfy the requirements, exception with a message containing [valueName] and calling [method] name is thrown. + */ + internal fun throwIfOutputNotSupported(valueInfo: ValueInfo, + valueName: String, + method: String, + type: OnnxJavaType + ) { + require(valueInfo !is MapInfo) { "Output $valueName is a Map, but currently method $method supports only $type Tensor outputs." } + require(valueInfo !is SequenceInfo) { "Output '$valueName' is a Sequence, but currently method $method supports $type float Tensor outputs." } + require(valueInfo is TensorInfo && valueInfo.type == type) { "Currently method $method supports only $type Tensor outputs, but output '$valueName' is not a float Tensor." } + } +} \ No newline at end of file diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceAlignmentModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceAlignmentModelBase.kt index 7f3f3c2a0..e534a6830 100644 --- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceAlignmentModelBase.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceAlignmentModelBase.kt @@ -5,8 +5,10 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment +import ai.onnxruntime.OrtSession import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxHighLevelModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray /** * Base class for face alignment models. 
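Aside, before the FaceAlignmentModelBase hunk that follows: with convert() now receiving an OrtSession.Result, the extraction boils down to one getFloatArray call. A small sketch mirroring that convert() implementation; toLandmarks is a hypothetical helper, and "fc1" is the output name used by the Fan2D106 example above.

import ai.onnxruntime.OrtSession
import org.jetbrains.kotlinx.dl.api.inference.facealignment.Landmark
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray

// The raw output is a flat float tensor of interleaved (x, y) pairs in [-1, 1],
// rescaled here to the [0, 1] range expected by Landmark.
fun toLandmarks(output: OrtSession.Result, outputName: String = "fc1"): List<Landmark> {
    val floats = output.getFloatArray(outputName)
    return (floats.indices step 2).map { i ->
        Landmark((1 + floats[i]) / 2, (1 + floats[i + 1]) / 2)
    }
}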
@@ -17,9 +19,9 @@ public abstract class FaceAlignmentModelBase : OnnxHighLevelModel): List { + override fun convert(output: OrtSession.Result): List { val landMarks = mutableListOf() - val floats = (output[outputName] as Array<*>)[0] as FloatArray + val floats = output.getFloatArray(outputName) for (i in floats.indices step 2) { landMarks.add(Landmark((1 + floats[i]) / 2, (1 + floats[i + 1]) / 2)) } diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModelBase.kt index 400716be2..2b201ab4d 100644 --- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModelBase.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/facealignment/FaceDetectionModelBase.kt @@ -5,8 +5,10 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.facealignment +import ai.onnxruntime.OrtSession import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxHighLevelModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray import java.lang.Float.min import kotlin.math.max @@ -15,9 +17,9 @@ import kotlin.math.max */ public abstract class FaceDetectionModelBase : OnnxHighLevelModel> { - override fun convert(output: Map): List { - val scores = (output["scores"] as Array<*>)[0] as Array - val boxes = (output["boxes"] as Array<*>)[0] as Array + override fun convert(output: OrtSession.Result): List { + val scores = output.get2DFloatArray("scores") + val boxes = output.get2DFloatArray("boxes") if (scores.isEmpty()) return emptyList() diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt index bf7cf8522..97341664b 100644 --- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/ObjectDetectionModelBase.kt @@ -5,8 +5,11 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection +import ai.onnxruntime.OrtSession import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxHighLevelModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray /** * Base class for object detection models. @@ -38,9 +41,9 @@ public abstract class ObjectDetectionModelBase : OnnxHighLevelModel : ObjectDetectionModelBase() { - override fun convert(output: Map): List { + override fun convert(output: OrtSession.Result): List { val foundObjects = mutableListOf() - val items = (output[OUTPUT_NAME] as Array>)[0] + val items = output.get2DFloatArray(OUTPUT_NAME) for (i in items.indices) { val probability = items[i][5] @@ -69,10 +72,10 @@ public abstract class EfficientDetObjectDetectionModelBase : ObjectDetectionM * Base class for object detection model based on SSD architecture. 
*/ public abstract class SSDLikeModelBase(protected val metadata: SSDLikeModelMetadata) : ObjectDetectionModelBase() { - override fun convert(output: Map): List { - val boxes = (output[metadata.outputBoxesName] as Array>)[0] - val classIndices = (output[metadata.outputClassesName] as Array)[0] - val probabilities = (output[metadata.outputScoresName] as Array)[0] + override fun convert(output: OrtSession.Result): List { + val boxes = output.get2DFloatArray(metadata.outputBoxesName) + val classIndices = output.getFloatArray(metadata.outputClassesName) + val probabilities = output.getFloatArray(metadata.outputScoresName) val numberOfFoundObjects = boxes.size val foundObjects = mutableListOf() diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/MultiPoseDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/MultiPoseDetectionModelBase.kt index be4cccd7c..f8ccec2a2 100644 --- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/MultiPoseDetectionModelBase.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/MultiPoseDetectionModelBase.kt @@ -5,8 +5,10 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection +import ai.onnxruntime.OrtSession import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxHighLevelModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose import org.jetbrains.kotlinx.dl.api.inference.posedetection.MultiPoseDetectionResult import org.jetbrains.kotlinx.dl.api.inference.posedetection.PoseLandmark @@ -30,8 +32,8 @@ public abstract class MultiPoseDetectionModelBase : OnnxHighLevelModel> - override fun convert(output: Map): MultiPoseDetectionResult { - val rawPoseLandMarks = (output[outputName] as Array>)[0] + override fun convert(output: OrtSession.Result): MultiPoseDetectionResult { + val rawPoseLandMarks = output.get2DFloatArray(outputName) val poses = rawPoseLandMarks.map { floats -> val foundPoseLandmarks = mutableListOf() diff --git a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModelBase.kt b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModelBase.kt index 57c54fc2a..094b0a9cc 100644 --- a/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModelBase.kt +++ b/onnx/src/commonMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/posedetection/SinglePoseDetectionModelBase.kt @@ -5,7 +5,9 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection +import ai.onnxruntime.OrtSession import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxHighLevelModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray import org.jetbrains.kotlinx.dl.api.inference.posedetection.DetectedPose import org.jetbrains.kotlinx.dl.api.inference.posedetection.PoseEdge import org.jetbrains.kotlinx.dl.api.inference.posedetection.PoseLandmark @@ -33,8 +35,8 @@ public abstract class SinglePoseDetectionModelBase : OnnxHighLevelModel> = edgeKeyPointsPairs - override fun convert(output: Map): DetectedPose { - val rawPoseLandMarks = (output[outputName] as Array>>)[0][0] + override fun convert(output: 
OrtSession.Result): DetectedPose { + val rawPoseLandMarks = output.get2DFloatArray(outputName) val foundPoseLandmarks = mutableListOf() for (i in rawPoseLandMarks.indices) { diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/dataset/preprocessor/ONNXModelPreprocessor.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/dataset/preprocessor/ONNXModelPreprocessor.kt index 0cb1a475c..a5c8668af 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/dataset/preprocessor/ONNXModelPreprocessor.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/dataset/preprocessor/ONNXModelPreprocessor.kt @@ -5,10 +5,11 @@ package org.jetbrains.kotlinx.dl.api.dataset.preprocessor -import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArrayWithShape import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation import org.jetbrains.kotlinx.dl.dataset.preprocessing.PreprocessingPipeline +import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape /** * Applies the given [onnxModel] as a preprocessing stage. @@ -19,8 +20,10 @@ import org.jetbrains.kotlinx.dl.dataset.preprocessing.PreprocessingPipeline public class ONNXModelPreprocessor(public var onnxModel: OnnxInferenceModel?, public var outputIndex: Int = 0) : Operation, Pair> { override fun apply(input: Pair): Pair { - val (prediction, rawShape) = onnxModel!!.predictRawWithShapes(input.first)[outputIndex] - return prediction.array() to TensorShape(rawShape) + val (prediction, rawShape) = onnxModel!!.predictRaw(input.first) { output -> + return@predictRaw output.getFloatArrayWithShape(outputIndex) + } + return prediction to TensorShape(rawShape) } override fun getOutputShape(inputShape: TensorShape): TensorShape { diff --git a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt index c134ed89c..0733bed40 100644 --- a/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt +++ b/onnx/src/jvmMain/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/objectdetection/SSDObjectDetectionModel.kt @@ -5,10 +5,14 @@ package org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection +import ai.onnxruntime.OrtSession import org.jetbrains.kotlinx.dl.api.inference.InferenceModel import org.jetbrains.kotlinx.dl.api.inference.objectdetection.DetectedObject import org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.get2DFloatArray +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArray +import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getLongArray import org.jetbrains.kotlinx.dl.dataset.Coco import org.jetbrains.kotlinx.dl.dataset.image.ColorMode import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter @@ -80,10 +84,10 @@ public class SSDObjectDetectionModel(override val internalModel: OnnxInferenceMo } // TODO remove code duplication due to different type of class labels array - override fun convert(output: Map): List { - val boxes = (output[metadata.outputBoxesName] as Array>)[0] - val classIndices = (output[metadata.outputClassesName] 
as Array)[0] - val probabilities = (output[metadata.outputScoresName] as Array)[0] + override fun convert(output: OrtSession.Result): List { + val boxes = output.get2DFloatArray(metadata.outputBoxesName) + val classIndices = output.getLongArray(metadata.outputClassesName) + val probabilities = output.getFloatArray(metadata.outputScoresName) val numberOfFoundObjects = boxes.size val foundObjects = mutableListOf()
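
Taken together, the call-site pattern after this change looks like the minimal sketch below. The function and variable names are illustrative, the model is assumed to be loaded the same way as in the examples above (e.g. via ONNXModelHub), and output index 0 is assumed to be a float tensor; adjust the name or index to the model at hand.

import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.OrtSessionResultConversions.getFloatArrayWithShape

fun describeFirstOutput(model: OnnxInferenceModel, input: FloatArray) {
    // Preferred path: copy a single output into a flat FloatArray together with its shape.
    // The OrtSession.Result is closed when the lambda returns, so copy what you need inside it.
    val (values, shape) = model.predictRaw(input) { it.getFloatArrayWithShape(0) }
    println("output 0: shape=${shape.contentToString()}, ${values.size} values")

    // Fallback path: the no-lambda overload still returns every output keyed by name
    // as multidimensional arrays, which can be slow for large tensors.
    val all = model.predictRaw(input)
    println("all outputs: ${all.keys}")
}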