Skip to content

Commit

Permalink
dnsmos
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Dec 5, 2024
1 parent aec0600 commit 2ceaa77
Showing 1 changed file with 18 additions and 22 deletions.
40 changes: 18 additions & 22 deletions src/torchmetrics/functional/audio/dnsmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@

if _LIBROSA_AVAILABLE and _ONNXRUNTIME_AVAILABLE and _REQUESTS_AVAILABLE:
import librosa
import onnxruntime as ort
import requests
from onnxruntime import InferenceSession
from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers
else:
librosa, ort, requests = None, None, None # type:ignore

Expand Down Expand Up @@ -97,25 +96,22 @@ def _load_session(
if not os.path.exists(path):
_prepare_dnsmos(DNSMOS_DIR)

opts = ort.SessionOptions()
opts = SessionOptions()
if num_threads is not None:
opts.inter_op_num_threads = num_threads
opts.intra_op_num_threads = num_threads

if device.type == "cpu":
infs = InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts)
elif "CUDAExecutionProvider" in ort.get_available_providers(): # win or linux with cuda
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
provider_options = [{"device_id": device.index}, {}]
infs = InferenceSession(path, providers=providers, provider_options=provider_options, sess_options=opts)
elif "CoreMLExecutionProvider" in ort.get_available_providers(): # macos with coreml
providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
provider_options = [{"device_id": device.index}, {}]
infs = InferenceSession(path, providers=providers, provider_options=provider_options, sess_options=opts)
else:
infs = InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts)

return infs
if device.type != "cpu":
if "CUDAExecutionProvider" in get_available_providers(): # win or linux with cuda
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
provider_options = [{"device_id": device.index}, {}]
return InferenceSession(path, providers=providers, provider_options=provider_options, sess_options=opts)
if "CoreMLExecutionProvider" in get_available_providers(): # macos with coreml
providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
provider_options = [{"device_id": device.index}, {}]
return InferenceSession(path, providers=providers, provider_options=provider_options, sess_options=opts)
raise NotImplementedError("No GPU or CoreML provider found, reverting to CPU.")
return InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts)


def _audio_melspec(
Expand Down Expand Up @@ -256,13 +252,13 @@ def deep_noise_suppression_mean_opinion_score(
input_features = np.array(audio_seg).astype("float32")
p808_input_features = np.array(_audio_melspec(audio=audio_seg[..., :-160])).astype("float32")

if device.type != "cpu" and (
"CUDAExecutionProvider" in ort.get_available_providers()
or "CoreMLExecutionProvider" in ort.get_available_providers()
onnx_available_providers = get_available_providers()
if device.type != "cpu" and any(
p in onnx_available_providers for p in ["CUDAExecutionProvider", "CoreMLExecutionProvider"]
):
try:
input_features = ort.OrtValue.ortvalue_from_numpy(input_features, device.type, device.index)
p808_input_features = ort.OrtValue.ortvalue_from_numpy(p808_input_features, device.type, device.index)
input_features = OrtValue.ortvalue_from_numpy(input_features, device.type, device.index)
p808_input_features = OrtValue.ortvalue_from_numpy(p808_input_features, device.type, device.index)
except Exception as e:
rank_zero_warn(f"Failed to use GPU for DNSMOS, reverting to CPU. Error: {e}")

Expand Down

0 comments on commit 2ceaa77

Please sign in to comment.