[iOS Object Detection] Move create ORTSession to init() (#19)
* wip

* move env to class variable

* fix typo in readmes

Co-authored-by: rachguo <[email protected]>
YUNQIUGUO and rachguo authored Aug 5, 2021
1 parent aadceef commit e869741
Showing 5 changed files with 40 additions and 29 deletions.
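The heart of the change: `ORTEnv` and `ORTSession` used to be created inside `runModel` on every camera frame, and this commit hoists them into a failable initializer so they are built once and reused. Below is a minimal sketch of that pattern, assuming the onnxruntime-objc Swift bindings the diff itself uses; the `DetectorHandle` name, the `.warning` logging level, and the import line are illustrative assumptions, not part of the commit.

```swift
import Foundation
import onnxruntime_objc  // assumed module name for the onnxruntime-objc pod

typealias FileInfo = (name: String, extension: String)

// Sketch: the environment and session are created once, in a failable
// initializer, instead of on every camera frame.
final class DetectorHandle {
    private let env: ORTEnv
    private let session: ORTSession
    let threadCount: Int32

    init?(modelFileInfo: FileInfo, threadCount: Int32 = 1) {
        // Fail fast if the .ort model is not in the app bundle.
        guard let modelPath = Bundle.main.path(forResource: modelFileInfo.name,
                                               ofType: modelFileInfo.extension) else {
            return nil
        }
        self.threadCount = threadCount
        do {
            env = try ORTEnv(loggingLevel: ORTLoggingLevel.warning)
            let options = try ORTSessionOptions()
            try options.setIntraOpNumThreads(threadCount)
            // The expensive object: built exactly once here, reused per frame.
            session = try ORTSession(env: env, modelPath: modelPath, sessionOptions: options)
        } catch {
            print("Failed to create ORTSession: \(error)")
            return nil
        }
    }
}
```

Creating the session per frame, as the old `runModel` in the diff below did, repeats model loading and graph optimization on every camera frame; hoisting it into `init?` pays that cost once.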
4 changes: 2 additions & 2 deletions mobile/README.md
@@ -42,8 +42,8 @@ The example app uses speech recognition to transcribe speech from audio recorded

- [iOS Speech Recognition](examples/speech_recognition/ios)

-## Object Detection
+### Object Detection

The example app uses object detection which is able to continuously detect the objects in the frames seen by your iOS device's back camera and display the detected object bounding boxes, detected class and corresponding inference confidence on the screen.

-- [iOS Object Detector](examples/object_detection/ios)
+- [iOS Object Detection](examples/object_detection/ios)
mobile/examples/object_detection/ios/ORTObjectDetection/ModelHandler.swift
@@ -72,20 +72,14 @@ class ModelHandler: NSObject {

    private var labels: [String] = []

-    init(threadCount: Int32 = 1) {
-        self.threadCount = threadCount
-
-        super.init()
-    }
-
-    // This method preprocesses the image, runs the ORT inference session and returns the inference result
-    func runModel(onFrame pixelBuffer: CVPixelBuffer, modelFileInfo: FileInfo, labelsFileInfo: FileInfo)
-        throws -> Result?
-    {
+    /// ORT inference session and environment object for performing inference on the given ssd model
+    private var session: ORTSession
+    private var env: ORTEnv
+
+    // MARK: - Initialization of ModelHandler
+    init?(modelFileInfo: FileInfo, labelsFileInfo: FileInfo, threadCount: Int32 = 1) {
        let modelFilename = modelFileInfo.name

-        labels = loadLabels(fileInfo: labelsFileInfo)
-
        guard let modelPath = Bundle.main.path(
            forResource: modelFilename,
            ofType: modelFileInfo.extension
@@ -94,6 +88,27 @@ class ModelHandler: NSObject {
            return nil
        }

+        self.threadCount = threadCount
+        do {
+            // Start the ORT inference environment and specify the options for session
+            env = try ORTEnv(loggingLevel: ORTLoggingLevel.verbose)
+            let options = try ORTSessionOptions()
+            try options.setLogSeverityLevel(ORTLoggingLevel.verbose)
+            try options.setIntraOpNumThreads(threadCount)
+            // Create the ORTSession
+            session = try ORTSession(env: env, modelPath: modelPath, sessionOptions: options)
+        } catch {
+            print("Failed to create ORTSession.")
+            return nil
+        }
+
+        super.init()
+
+        labels = loadLabels(fileInfo: labelsFileInfo)
+    }
+
+    // This method preprocesses the image, runs the ORT inference session and returns the inference result
+    func runModel(onFrame pixelBuffer: CVPixelBuffer) throws -> Result? {
        let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
        assert(sourcePixelFormat == kCVPixelFormatType_32ARGB ||
               sourcePixelFormat == kCVPixelFormatType_32BGRA ||
@@ -111,14 +126,6 @@ class ModelHandler: NSObject {
            return nil
        }

-        // Start the ORT inference environment
-        let env = try ORTEnv(loggingLevel: ORTLoggingLevel.warning)
-        let options = try ORTSessionOptions()
-        try options.setLogSeverityLevel(ORTLoggingLevel.verbose)
-        try options.setIntraOpNumThreads(threadCount)
-
-        let session = try ORTSession(env: env, modelPath: modelPath, sessionOptions: options)

        let interval: TimeInterval

        let inputName = "normalized_input_image_tensor"
@@ -135,6 +142,7 @@
                                      inputHeight as NSNumber,
                                      inputWidth as NSNumber,
                                      inputChannels as NSNumber]
+
        let inputTensor = try ORTValue(tensorData: NSMutableData(data: rgbData),
                                       elementType: ORTTensorElementDataType.uInt8,
                                       shape: inputShape)
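Since the session now lives on the handler, each frame's `runModel` call reduces to building the input tensor and invoking the stored session. A hedged sketch of that per-frame call follows; the input name matches the diff above, while the function name and output names are placeholders, not the app's actual identifiers.

```swift
import onnxruntime_objc  // assumed module name for the onnxruntime-objc pod

// Runs one frame through an already-created session.
// "normalized_input_image_tensor" matches the diff; the output names below
// are placeholders standing in for the SSD model's real ones.
func runOnce(session: ORTSession, inputTensor: ORTValue) throws -> [String: ORTValue] {
    let outputNames: Set<String> = ["detection_boxes_placeholder",
                                    "detection_classes_placeholder"]
    return try session.run(withInputs: ["normalized_input_image_tensor": inputTensor],
                           outputNames: outputNames,
                           runOptions: nil)
}
```

With this shape, a failed frame throws to the caller, while the session object itself survives for the next frame.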
mobile/examples/object_detection/ios/ORTObjectDetection/Main.storyboard
@@ -38,7 +38,7 @@
</subviews>
<color key="backgroundColor" white="1" alpha="0.0" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
</view>
-<view contentMode="scaleToFill" fixedFrame="YES" translatesAutoresizingMaskIntoConstraints="NO" id="ro1-YL-L1d" customClass="CurvedView" customModule="ORTObjectDetection" customModuleProvider="target">
+<view contentMode="scaleToFill" fixedFrame="YES" translatesAutoresizingMaskIntoConstraints="NO" id="ro1-YL-L1d">
<rect key="frame" x="0.0" y="579" width="390" height="231"/>
<autoresizingMask key="autoresizingMask" flexibleMaxX="YES" flexibleMaxY="YES"/>
<subviews>
mobile/examples/object_detection/ios/ORTObjectDetection/ViewController.swift
@@ -36,7 +36,9 @@ class ViewController: UIViewController {
    private var inferenceViewController: InferenceViewController?

    // Handle all model and data preprocessing and run inference
-    private var modelHandler: ModelHandler? = ModelHandler()
+    private var modelHandler: ModelHandler? = ModelHandler(
+        modelFileInfo: (name: "ssd_mobilenet_v1.all", extension: "ort"),
+        labelsFileInfo: (name: "labelmap", extension: "txt"))

    // MARK: View Controller Life Cycle

@@ -89,7 +91,9 @@
extension ViewController: InferenceViewControllerDelegate {
    func didChangeThreadCount(to count: Int32) {
        if modelHandler?.threadCount == count { return }
-        modelHandler = ModelHandler(threadCount: count)
+        modelHandler = ModelHandler(modelFileInfo: (name: "ssd_mobilenet_v1.all", extension: "ort"),
+                                    labelsFileInfo: (name: "labelmap", extension: "txt"),
+                                    threadCount: count)
    }
}

@@ -134,9 +138,7 @@ extension ViewController: CameraManagerDelegate {
        else { return }
        previousInferenceTimeMs = currentTimeMs

-        result = try! modelHandler?.runModel(onFrame: pixelBuffer,
-                                             modelFileInfo: (name: "ssd_mobilenet_v1.all", extension: "ort"),
-                                             labelsFileInfo: (name: "labelmap", extension: "txt"))
+        result = try! self.modelHandler?.runModel(onFrame: pixelBuffer)

        guard let displayResult = result else {
            return
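One consequence of the new design shows up in `didChangeThreadCount` above: rebuilding the handler now also rebuilds the session with the new intra-op thread count. Note that the per-frame call still uses `try!`, which crashes the app if inference throws; a defensive variant, an assumption on my part and not part of this commit, would wrap the call:

```swift
import CoreVideo

// Sketch only: ModelHandler and Result are the example app's own types.
func detectSafely(_ pixelBuffer: CVPixelBuffer, using handler: ModelHandler?) -> Result? {
    do {
        return try handler?.runModel(onFrame: pixelBuffer)
    } catch {
        print("Inference failed: \(error)")  // log and drop the frame instead of crashing
        return nil
    }
}
```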
3 changes: 2 additions & 1 deletion mobile/examples/object_detection/ios/README.md
@@ -23,7 +23,8 @@ The original `ssd_mobilenet_v1.tflite` model can be downloaded [here](https://ww
2. In terminal, run `pod install` under `<ONNXRuntime-inference-example-root>/mobile/examples/object_detection/ios/` to generate the workspace file.
   - At the end of this step, you should get a file called `ORTObjectDetection.xcworkspace`.

-3. In terminal, run the download script `./download.sh` or `bash download.sh` under `<ONNXRuntime-inference-example-root>/mobile/examples/object_detection/ios/ORTObjectDetection/`. The script will download the original tflite model along with the model metadata `labelmap.txt`, convert it to an onnx model, and then further convert it to an ort format model (the format that can be executed in mobile applications).
+3. In terminal, run the model preparation script.
+   - Run `./prepare_model.sh` or `bash prepare_model.sh` under `<ONNXRuntime-inference-example-root>/mobile/examples/object_detection/ios/ORTObjectDetection/`. The script downloads the original tflite model along with the model metadata `labelmap.txt`, converts it to an onnx model, and then further converts it to an ort format model (the format that can be executed in mobile applications).
   - At the end of this step, you should get a directory `ModelsAndData` which contains the ort format model `ssd_mobilenet_v1.all.ort` and the model label data file `labelmap.txt`.

Note: The generated model and data files might need to be copied to the app bundle. In Xcode: `Build Phases > Copy Bundle Resources > '+'`, then select the model file `ssd_mobilenet_v1.all.ort` and the label data file `labelmap.txt`.
