Skip to content

Commit

Permalink
Make ONNX runtime optional (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
juanmc2005 authored Nov 18, 2023
1 parent 8cad376 commit 8e9f74c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,3 @@ optuna>=2.10
websocket-server>=0.6.4
websocket-client>=0.58.0
rich>=12.5.1
onnxruntime-gpu>=1.16.1
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ install_requires=
websocket-server>=0.6.4
websocket-client>=0.58.0
rich>=12.5.1
onnxruntime-gpu>=1.16.1

[options.packages.find]
where=src
Expand Down
26 changes: 16 additions & 10 deletions src/diart/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Optional, Text, Union, Callable, List

import numpy as np
import onnxruntime
import torch
import torch.nn as nn
from requests import HTTPError
Expand All @@ -15,9 +14,16 @@
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio.utils.powerset import Powerset

_has_pyannote = True
IS_PYANNOTE_AVAILABLE = True
except ImportError:
_has_pyannote = False
IS_PYANNOTE_AVAILABLE = False

try:
import onnxruntime as ort

IS_ONNX_AVAILABLE = True
except ImportError:
IS_ONNX_AVAILABLE = False


class PowersetAdapter(nn.Module):
Expand Down Expand Up @@ -88,11 +94,9 @@ def execution_provider(self) -> str:
return f"{device}ExecutionProvider"

def recreate_session(self):
options = onnxruntime.SessionOptions()
options.graph_optimization_level = (
onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
)
self.session = onnxruntime.InferenceSession(
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
self.session = ort.InferenceSession(
self.path,
sess_options=options,
providers=[self.execution_provider],
Expand Down Expand Up @@ -168,7 +172,7 @@ def from_pyannote(
-------
wrapper: SegmentationModel
"""
assert _has_pyannote, "No pyannote.audio installation found"
assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found"
return SegmentationModel(PyannoteLoader(model, use_hf_token))

@staticmethod
Expand All @@ -177,6 +181,7 @@ def from_onnx(
input_name: str = "waveform",
output_name: str = "segmentation",
) -> "SegmentationModel":
assert IS_ONNX_AVAILABLE, "No ONNX installation found"
return SegmentationModel(ONNXLoader(model_path, [input_name], output_name))

@staticmethod
Expand Down Expand Up @@ -224,7 +229,7 @@ def from_pyannote(
-------
wrapper: EmbeddingModel
"""
assert _has_pyannote, "No pyannote.audio installation found"
assert IS_PYANNOTE_AVAILABLE, "No pyannote.audio installation found"
loader = PyannoteLoader(model, use_hf_token)
return EmbeddingModel(loader)

Expand All @@ -234,6 +239,7 @@ def from_onnx(
input_names: List[str] | None = None,
output_name: str = "embedding",
) -> "EmbeddingModel":
assert IS_ONNX_AVAILABLE, "No ONNX installation found"
input_names = input_names or ["waveform", "weights"]
loader = ONNXLoader(model_path, input_names, output_name)
return EmbeddingModel(loader)
Expand Down

0 comments on commit 8e9f74c

Please sign in to comment.