From e8697411a283e56e7981ea54a5eb72e3178a3e1d Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Thu, 5 Aug 2021 15:36:29 -0700 Subject: [PATCH] [iOS Object Detection] Move `create ORTSession` to init() (#19) * wip * move env to class variable * fix typo in readmes Co-authored-by: rachguo --- mobile/README.md | 4 +- .../ios/ORTObjectDetection/ModelHandler.swift | 48 +++++++++++-------- .../Storyboards/Main.storyboard | 2 +- .../ORTObjectDetection/ViewController.swift | 12 +++-- .../examples/object_detection/ios/README.md | 3 +- 5 files changed, 40 insertions(+), 29 deletions(-) diff --git a/mobile/README.md b/mobile/README.md index 6540f1721ba28..acbb158c153bb 100644 --- a/mobile/README.md +++ b/mobile/README.md @@ -42,8 +42,8 @@ The example app uses speech recognition to transcribe speech from audio recorded - [iOS Speech Recognition](examples/speech_recognition/ios) -## Object Detection +### Object Detection The example app uses object detection which is able to continuously detect the objects in the frames seen by your iOS device's back camera and display the detected object bounding boxes, detected class and corresponding inference confidence on the screen. 
-- [iOS Object Detector](examples/object_detection/ios) +- [iOS Object Detection](examples/object_detection/ios) diff --git a/mobile/examples/object_detection/ios/ORTObjectDetection/ModelHandler.swift b/mobile/examples/object_detection/ios/ORTObjectDetection/ModelHandler.swift index 592f137aeb12a..10657d25a8972 100644 --- a/mobile/examples/object_detection/ios/ORTObjectDetection/ModelHandler.swift +++ b/mobile/examples/object_detection/ios/ORTObjectDetection/ModelHandler.swift @@ -72,20 +72,14 @@ class ModelHandler: NSObject { private var labels: [String] = [] - init(threadCount: Int32 = 1) { - self.threadCount = threadCount - - super.init() - } - - // This method preprocesses the image, runs the ort inferencesession and returns the inference result - func runModel(onFrame pixelBuffer: CVPixelBuffer, modelFileInfo: FileInfo, labelsFileInfo: FileInfo) - throws -> Result? - { + /// ORT inference session and environment object for performing inference on the given ssd model + private var session: ORTSession + private var env: ORTEnv + + // MARK: - Initialization of ModelHandler + init?(modelFileInfo: FileInfo, labelsFileInfo: FileInfo, threadCount: Int32 = 1) { let modelFilename = modelFileInfo.name - labels = loadLabels(fileInfo: labelsFileInfo) - guard let modelPath = Bundle.main.path( forResource: modelFilename, ofType: modelFileInfo.extension @@ -94,6 +88,27 @@ class ModelHandler: NSObject { return nil } + self.threadCount = threadCount + do { + // Start the ORT inference environment and specify the options for session + env = try ORTEnv(loggingLevel: ORTLoggingLevel.verbose) + let options = try ORTSessionOptions() + try options.setLogSeverityLevel(ORTLoggingLevel.verbose) + try options.setIntraOpNumThreads(threadCount) + // Create the ORTSession + session = try ORTSession(env: env, modelPath: modelPath, sessionOptions: options) + } catch { + print("Failed to create ORTSession.") + return nil + } + + super.init() + + labels = loadLabels(fileInfo: labelsFileInfo) + 
} + + // This method preprocesses the image, runs the ort inferencesession and returns the inference result + func runModel(onFrame pixelBuffer: CVPixelBuffer) throws -> Result? { let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer) assert(sourcePixelFormat == kCVPixelFormatType_32ARGB || sourcePixelFormat == kCVPixelFormatType_32BGRA || @@ -111,14 +126,6 @@ class ModelHandler: NSObject { return nil } - // Start the ORT inference environment - let env = try ORTEnv(loggingLevel: ORTLoggingLevel.warning) - let options = try ORTSessionOptions() - try options.setLogSeverityLevel(ORTLoggingLevel.verbose) - try options.setIntraOpNumThreads(threadCount) - - let session = try ORTSession(env: env, modelPath: modelPath, sessionOptions: options) - let interval: TimeInterval let inputName = "normalized_input_image_tensor" @@ -135,6 +142,7 @@ class ModelHandler: NSObject { inputHeight as NSNumber, inputWidth as NSNumber, inputChannels as NSNumber] + let inputTensor = try ORTValue(tensorData: NSMutableData(data: rgbData), elementType: ORTTensorElementDataType.uInt8, shape: inputShape) diff --git a/mobile/examples/object_detection/ios/ORTObjectDetection/Storyboards/Main.storyboard b/mobile/examples/object_detection/ios/ORTObjectDetection/Storyboards/Main.storyboard index 2852f49fa1f6d..8f7dda8cd3998 100644 --- a/mobile/examples/object_detection/ios/ORTObjectDetection/Storyboards/Main.storyboard +++ b/mobile/examples/object_detection/ios/ORTObjectDetection/Storyboards/Main.storyboard @@ -38,7 +38,7 @@ - + diff --git a/mobile/examples/object_detection/ios/ORTObjectDetection/ViewController.swift b/mobile/examples/object_detection/ios/ORTObjectDetection/ViewController.swift index cc58eef25a3ac..155e1ed829f32 100644 --- a/mobile/examples/object_detection/ios/ORTObjectDetection/ViewController.swift +++ b/mobile/examples/object_detection/ios/ORTObjectDetection/ViewController.swift @@ -36,7 +36,9 @@ class ViewController: UIViewController { private var 
inferenceViewController: InferenceViewController? // Handle all model and data preprocessing and run inference - private var modelHandler: ModelHandler? = ModelHandler() + private var modelHandler: ModelHandler? = ModelHandler( + modelFileInfo: (name: "ssd_mobilenet_v1.all", extension: "ort"), + labelsFileInfo: (name: "labelmap", extension: "txt")) // MARK: View Controller Life Cycle @@ -89,7 +91,9 @@ class ViewController: UIViewController { extension ViewController: InferenceViewControllerDelegate { func didChangeThreadCount(to count: Int32) { if modelHandler?.threadCount == count { return } - modelHandler = ModelHandler(threadCount: count) + modelHandler = ModelHandler(modelFileInfo: (name: "ssd_mobilenet_v1.all", extension: "ort"), + labelsFileInfo: (name: "labelmap", extension: "txt"), + threadCount: count) } } @@ -134,9 +138,7 @@ extension ViewController: CameraManagerDelegate { else { return } previousInferenceTimeMs = currentTimeMs - result = try! modelHandler?.runModel(onFrame: pixelBuffer, - modelFileInfo: (name: "ssd_mobilenet_v1.all", extension: "ort"), - labelsFileInfo: (name: "labelmap", extension: "txt")) + result = try! self.modelHandler?.runModel(onFrame: pixelBuffer) guard let displayResult = result else { return diff --git a/mobile/examples/object_detection/ios/README.md b/mobile/examples/object_detection/ios/README.md index 9e7d922637d7a..aadeb9f09ebf2 100644 --- a/mobile/examples/object_detection/ios/README.md +++ b/mobile/examples/object_detection/ios/README.md @@ -23,7 +23,8 @@ The original `ssd_mobilenet_v1.tflite` model can be downloaded [here](https://ww 2. In terminal, run `pod install` under `/mobile/examples/object_detections/ios/` to generate the workspace file. - At the end of this step, you should get a file called `ORTObjectDetection.xcworkspace`. -3. In terminal, run download script `./download.sh` or `bash download.sh` under `/mobile/examples/object_detections/ios/ORTObjectDetection/`. 
The script will download an original tflite model along with the model metadata `labelmap.txt` and convert it to onnx model and then further convert it to ort format model (this is the format can be executed on mobile applications). +3. In terminal, run the script to prepare the model. +- Run `./prepare_model.sh` or `bash prepare_model.sh` under `/mobile/examples/object_detections/ios/ORTObjectDetection/`. The script will download an original tflite model along with the model metadata `labelmap.txt` and convert it to onnx model and then further convert it to ort format model (this is the format that can be executed on mobile applications). - At the end of this step, you should get a directory `ModelsAndData` which contains the ort format model `ssd_mobilenet_v1.all.ort` and model label data file `labelmap.txt`. Note: The model and data files generated might need to be copied to app bundle. i.e. In Xcode, `Build phases-Expand Copy Bundle Resources-Click '+' and select model file name "ssd_mobilenet_v1.all.ort" and select label data file "labelmap.txt"`.