Skip to content

Commit

Permalink
Add OpenVINO backend
Browse files Browse the repository at this point in the history
Load model from ONNX representation (#3)

CRF model with OpenVINO

LogZ cpu (#4)

Fix TBB dependency

OpenVINO refactoring

Use OpenVINO 2021.2

Add tests

Use official OpenVINO package (#7)

Remote tests
  • Loading branch information
dkurt committed Apr 22, 2021
1 parent 8a68957 commit 86b437d
Show file tree
Hide file tree
Showing 11 changed files with 295 additions and 14 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ If a reference is provided in either `.fasta` or `.mmi` format then bonito will
$ bonito basecaller dna_r9.4.1 --reference reference.mmi /data/reads > basecalls.sam
```

To optimize inference on CPU with Intel OpenVINO:

```bash
$ bonito basecaller dna_r9.4.1 --reference reference.mmi --use_openvino --device=cpu /data/reads > basecalls.sam
```

## Developer Quickstart

```bash
Expand Down
3 changes: 2 additions & 1 deletion bonito/cli/basecaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main(args):
exit(1)

sys.stderr.write("> loading model\n")
model = load_model(args.model_directory, args.device, weights=int(args.weights))
model = load_model(args.model_directory, args.device, weights=int(args.weights), use_openvino=args.use_openvino)

if args.reference:
sys.stderr.write("> loading reference\n")
Expand Down Expand Up @@ -96,6 +96,7 @@ def argparser():
parser.add_argument("--save-ctc", action="store_true", default=False)
parser.add_argument("--revcomp", action="store_true", default=False)
parser.add_argument("--recursive", action="store_true", default=False)
parser.add_argument("--use_openvino", action="store_true", default=False)
parser.add_argument("--ctc-min-coverage", default=0.9, type=float)
parser.add_argument("--ctc-min-accuracy", default=0.9, type=float)
parser.add_argument("--batchsize", default=32, type=int)
Expand Down
3 changes: 2 additions & 1 deletion bonito/cli/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def main(args):
seqs = []

print("* loading model", w)
model = load_model(args.model_directory, args.device, weights=w)
model = load_model(args.model_directory, args.device, weights=w, use_openvino=args.use_openvino)

print("* calling")
t0 = time.perf_counter()
Expand Down Expand Up @@ -101,4 +101,5 @@ def argparser():
parser.add_argument("--beamsize", default=5, type=int)
parser.add_argument("--poa", action="store_true", default=False)
parser.add_argument("--min-coverage", default=0.5, type=float)
parser.add_argument("--use_openvino", action="store_true", default=False)
return parser
8 changes: 7 additions & 1 deletion bonito/crf/basecall.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def compute_scores(model, batch, reverse=False):
"""
with torch.no_grad():
device = next(model.parameters()).device
dtype = torch.float16 if half_supported() else torch.float32
dtype = torch.float16 if device != torch.device('cpu') and half_supported() else torch.float32
scores = model(batch.to(dtype).to(device))
if reverse: scores = model.seqdist.reverse_complement(scores)
betas = model.seqdist.backward_scores(scores.to(torch.float32))
Expand Down Expand Up @@ -63,6 +63,12 @@ def transfer(x):
"""
Device to host transfer using pinned memory.
"""
if not torch.cuda.is_available():
return {
k: torch.empty(v.shape, pin_memory=False, dtype=v.dtype).copy_(v).numpy()
for k, v in x.items()
}

torch.cuda.synchronize()
with torch.cuda.stream(torch.cuda.Stream()):
return {
Expand Down
67 changes: 63 additions & 4 deletions bonito/crf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import numpy as np
from bonito.nn import Module, Convolution, LinearCRFEncoder, Serial, Permute, layers, from_dict

import seqdist.sparse
from seqdist.ctc_simple import logZ_cupy, viterbi_alignments
if torch.cuda.is_available():
import seqdist.sparse
from seqdist.ctc_simple import logZ_cupy, viterbi_alignments
from seqdist.core import SequenceDist, Max, Log, semiring


Expand All @@ -21,6 +22,58 @@ def get_stride(m):
return 1


def logZ_fwd_cpu(Ms, idx, v0, vT, S):
    """
    Forward pass of the sparse partition-function recursion in plain PyTorch.

    Args:
        Ms: transition scores; unpacked here as a 4D tensor (T, N, C, NZ).
        idx: integer index tensor used to gather the previous forward values
            (`a[:, idx]`); presumably of shape (C, NZ) — confirm against caller.
        v0: initial forward values, indexed as (N, C).
        vT: final values added to the last forward values before reduction.
        S: semiring-like object providing `mul(a, b)` and `sum(x, dim)`.

    Returns:
        A pair `(logZ, Ms_grad)`: the values reduced over dim 1, and a tensor
        shaped like `Ms` holding the per-step products `S.mul(a[:, idx], Ms[t])`.
    """
    T, N, C, NZ = Ms.shape
    # Allocate with the dtype/device of Ms so that non-float32 inputs are not
    # silently cast when the per-step products are stored.
    Ms_grad = Ms.new_zeros(T, N, C, NZ)

    a = v0
    for t in range(T):
        s = S.mul(a[:, idx], Ms[t])
        a = S.sum(s, -1)
        Ms_grad[t] = s
    return S.sum(a + vT, dim=1), Ms_grad


def logZ_bwd_cpu(Ms, idx, vT, S, K=1):
    """
    Backward scores of the sparse recursion in plain PyTorch.

    Args:
        Ms: transition scores; unpacked here as a 4D tensor (T, N, C, NZ).
        idx: integer index tensor; it is flattened and arg-sorted to build the
            transposed lookup, so it is expected to hold C * NZ entries.
        vT: final values, stored as the last row of the result (`betas[T]`).
        S: semiring-like object providing `mul(a, b)` and `sum(x, dim)`.
        K: only `K == 1` is supported by this implementation.

    Returns:
        Tensor of shape (T + 1, N, C) with the backward values for every step.
    """
    assert K == 1, "logZ_bwd_cpu only supports K == 1"
    T, N, C, NZ = Ms.shape
    Ms = Ms.reshape(T, N, -1)
    # Positions of the flattened idx in sorted order, reshaped back to (C, NZ).
    idx_T = idx.flatten().argsort().to(dtype=torch.long).reshape(C, NZ)

    # Allocate with the dtype/device of Ms so that non-float32 inputs are not
    # silently cast when the backward values are stored.
    betas = Ms.new_ones(T + 1, N, C)

    a = vT
    betas[T] = a
    for t in reversed(range(T)):
        s = S.mul(a[:, idx_T // NZ], Ms[t, :, idx_T])
        a = S.sum(s, -1)
        betas[t] = a
    return betas


class _LogZ(torch.autograd.Function):
    """
    Autograd function wrapping the pure-PyTorch forward/backward recursions
    (`logZ_fwd_cpu` / `logZ_bwd_cpu`). Only `Ms` receives a gradient.
    """

    @staticmethod
    def forward(ctx, Ms, idx, v0, vT, S:semiring):
        """Compute the forward value and stash what `backward` needs."""
        idx = idx.to(dtype=torch.long, device=Ms.device)
        logZ, Ms_grad = logZ_fwd_cpu(Ms, idx, v0, vT, S)
        ctx.save_for_backward(Ms_grad, Ms, idx, vT)
        # The semiring is not a tensor, so it is kept directly on the context.
        ctx.semiring = S
        return logZ

    @staticmethod
    def backward(ctx, grad):
        """Return the gradient w.r.t. `Ms`; all other inputs get `None`."""
        Ms_grad, Ms, idx, vT = ctx.saved_tensors
        S = ctx.semiring
        T, N, C, NZ = Ms.shape
        betas = logZ_bwd_cpu(Ms, idx, vT, S)
        Ms_grad = S.mul(Ms_grad, betas[1:,:,:,None])
        Ms_grad = S.dsum(Ms_grad.reshape(T, N, -1), dim=2).reshape(T, N, C, NZ)
        # One gradient slot per forward input: Ms, idx, v0, vT, S.
        return grad[None, :, None, None] * Ms_grad, None, None, None, None

def sparse_logZ(Ms, idx, v0, vT, S:semiring=Log):
    """
    Differentiable entry point for the pure-PyTorch recursion.

    Forwards all arguments unchanged to `_LogZ.apply` and returns its result.
    """
    result = _LogZ.apply(Ms, idx, v0, vT, S)
    return result


class CTC_CRF(SequenceDist):

def __init__(self, state_len, alphabet):
Expand All @@ -43,7 +96,10 @@ def logZ(self, scores, S:semiring=Log):
Ms = scores.reshape(T, N, -1, len(self.alphabet))
alpha_0 = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
return seqdist.sparse.logZ(Ms, self.idx, alpha_0, beta_T, S)
if not Ms.device.index is None:
return seqdist.sparse.logZ(Ms, self.idx, alpha_0, beta_T, S)
else:
return sparse_logZ(Ms, self.idx, alpha_0, beta_T, S)

def normalise(self, scores):
return (scores - self.logZ(scores)[:, None] / len(scores))
Expand All @@ -58,7 +114,10 @@ def backward_scores(self, scores, S: semiring=Log):
T, N, _ = scores.shape
Ms = scores.reshape(T, N, -1, self.n_base + 1)
beta_T = Ms.new_full((N, self.n_base**(self.state_len)), S.one)
return seqdist.sparse.bwd_scores_cupy(Ms, self.idx, beta_T, S, K=1)
if not Ms.device.index is None:
return seqdist.sparse.bwd_scores_cupy(Ms, self.idx, beta_T, S, K=1)
else:
return logZ_bwd_cpu(Ms, self.idx, beta_T, S, K=1)

def reverse_complement(self, scores):
T, N, C = scores.shape
Expand Down
3 changes: 2 additions & 1 deletion bonito/ctc/basecall.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def compute_scores(model, batch):
"""
with torch.no_grad():
device = next(model.parameters()).device
chunks = batch.to(torch.half).to(device)
chunks = batch.to(torch.half) if half_supported() else batch
chunks = chunks.to(device)
probs = permute(model(chunks), 'TNC', 'NTC')
return probs.cpu().to(torch.float32)

Expand Down
2 changes: 1 addition & 1 deletion bonito/ctc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,4 @@ def __init__(self, features, classes):
)

def forward(self, x):
return log_softmax(self.layers(x), dim=2)
return log_softmax(self.layers(x), dim=-1)
31 changes: 31 additions & 0 deletions bonito/openvino/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch.nn as nn


def convert_to_2d(model):
    """
    Recursively rewrite the 1D layers of `model` in place as 2D equivalents.

    * `Conv1d` becomes a `Conv2d` with a `(1, k)` kernel; the weights are
      reused with an extra unit dimension inserted at position 2.
    * `BatchNorm1d` becomes a `BatchNorm2d` with the same hyper-parameters and
      state, switched to eval mode.
    * `Permute` layers (matched by class name) get their `dims` extended so the
      new axis 3 follows axis 2.
    * Any other child is descended into recursively.

    Args:
        model: module whose children are converted; it is modified in place.
    """
    # Snapshot the children first: entries are replaced while we walk them.
    for name, l in list(model.named_children()):
        layer_type = l.__class__.__name__
        if layer_type == 'Conv1d':
            new_layer = nn.Conv2d(l.in_channels, l.out_channels,
                                  kernel_size=(1, l.kernel_size[0]),
                                  stride=(1, l.stride[0]),
                                  padding=(0, l.padding[0]),
                                  dilation=(1, l.dilation[0]),
                                  groups=l.groups,
                                  bias=l.bias is not None,
                                  padding_mode=l.padding_mode)
            params = l.state_dict()
            # (out, in, k) -> (out, in, 1, k)
            params['weight'] = params['weight'].unsqueeze(2)
            new_layer.load_state_dict(params)
            setattr(model, name, new_layer)
        elif layer_type == 'BatchNorm1d':
            # Carry over every constructor option; otherwise a layer built with
            # affine=False or track_running_stats=False has a state dict that
            # does not match the replacement and load_state_dict() fails.
            new_layer = nn.BatchNorm2d(l.num_features, eps=l.eps,
                                       momentum=l.momentum, affine=l.affine,
                                       track_running_stats=l.track_running_stats)
            new_layer.load_state_dict(l.state_dict())
            new_layer.eval()
            setattr(model, name, new_layer)
        elif layer_type == 'Permute':
            dims_2d = []
            # 1D to 2D: i.e. (2, 0, 1) -> (2, 3, 0, 1)
            for d in l.dims:
                assert d <= 2, "Permute layer does not look like a 1D permutation"
                dims_2d.append(d)
                if d == 2:
                    dims_2d.append(3)
            l.dims = dims_2d
        else:
            convert_to_2d(l)
170 changes: 170 additions & 0 deletions bonito/openvino/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import os
import io
import numpy as np
import torch

from bonito.nn import Swish

try:
from openvino.inference_engine import IECore, StatusCode
from .loader import convert_to_2d
except ImportError:
pass


def load_openvino_model(model, dirname):
    """
    Wrap `model` in the OpenVINO runner matching its configured package.

    The package name is read from `model.config['model']['package']`:
    'bonito.ctc' selects `CTCModel`, 'bonito.crf' selects `CRFModel`;
    anything else raises an `Exception`.
    """
    package = model.config['model']['package']
    if package != 'bonito.ctc' and package != 'bonito.crf':
        raise Exception('Unknown model configuration: ' + package)

    if package == 'bonito.ctc':
        return CTCModel(model, dirname)
    return CRFModel(model, dirname)


class OpenVINOModel:
    """
    Base wrapper that runs a bonito model through the OpenVINO Inference Engine.

    It mimics the small part of the `torch.nn.Module` interface used on the
    wrapped model (`eval`, `half`, `to`, `parameters`) and defers creation of
    the executable network until `init_model` is called.
    """

    def __init__(self, model, dirname):
        """
        Args:
            model: the original PyTorch model; `alphabet`, `parameters` and
                `stride` are re-exported from it.
            dirname: directory searched for a pre-converted `model.xml` /
                `model.bin` pair.
        """
        self.model = model
        self.alphabet = model.alphabet
        self.parameters = model.parameters
        self.stride = model.stride
        self.net = None
        self.exec_net = None
        self.dirname = dirname
        # Default target so that init_model() works even if to() was never called.
        self.device = 'CPU'
        self.ie = IECore()

    def eval(self):
        """No-op kept for interface compatibility; returns self."""
        return self

    def half(self):
        """No-op kept for interface compatibility; returns self."""
        return self

    def to(self, device):
        """Remember the target device (upper-cased name) and return self."""
        self.device = str(device).upper()
        return self

    def init_model(self, model, inp_shape):
        """
        Call this method once to initialize executable network.

        Args:
            model: module exported to ONNX when no IR files exist on disk.
            inp_shape: shape of the random tensor used to trace the export.
        """
        # First, we try to check if there is IR on disk. If not - load model in runtime
        xml_path, bin_path = [os.path.join(self.dirname, 'model') + ext for ext in ['.xml', '.bin']]
        if os.path.exists(xml_path) and os.path.exists(bin_path):
            self.net = self.ie.read_network(xml_path, bin_path)
        else:
            # There is an issue with Swish at export step so we temporarily use default implementation
            origin_swish_forward = Swish.forward

            def swish_fake_forward(self, x):
                return x * torch.sigmoid(x)

            Swish.forward = swish_fake_forward
            try:
                # Convert model to ONNX buffer
                buf = io.BytesIO()
                inp = torch.randn(inp_shape)
                torch.onnx.export(model, inp, buf, input_names=['input'], output_names=['output'],
                                  opset_version=11)
            finally:
                # Always undo the class-level patch, even if the export fails.
                Swish.forward = origin_swish_forward

            # Import network from memory buffer
            self.net = self.ie.read_network(buf.getvalue(), b'', init_from_buffer=True)

        # Load model to device
        config = {}
        if self.device == 'CPU':
            config = {'CPU_THROUGHPUT_STREAMS': 'CPU_THROUGHPUT_AUTO'}
        self.exec_net = self.ie.load_network(self.net, self.device,
                                             config=config, num_requests=0)

    def process(self, data):
        """
        Run every item of the batch through the executable network.

        Each batch item is submitted as its own asynchronous infer request and
        the results are gathered into one array.

        Args:
            data: batched input tensor; dimension 0 is the batch.

        Returns:
            `torch.Tensor` of shape (out_shape[-3], batch, out_shape[-1]),
            where `out_shape` is the shape of the network's 'output'.

        Raises:
            Exception: if waiting for a request fails or no idle request can
                be obtained.
        """
        data = data.float()
        batch_size = data.shape[0]

        # List that maps infer requests to index of processed chunk from batch.
        # -1 means that request has not been started yet.
        infer_request_input_id = [-1] * len(self.exec_net.requests)
        out_shape = self.net.outputs['output'].shape
        # CTC network produces 1xWxNxC
        output = np.zeros([out_shape[-3], batch_size, out_shape[-1]], dtype=np.float32)

        for inp_id in range(batch_size):
            # Get idle infer request
            infer_request_id = self.exec_net.get_idle_request_id()
            if infer_request_id < 0:
                status = self.exec_net.wait(num_requests=1)
                if status != StatusCode.OK:
                    raise Exception("Wait for idle request failed!")
                infer_request_id = self.exec_net.get_idle_request_id()
                if infer_request_id < 0:
                    raise Exception("Invalid request id!")

            out_id = infer_request_input_id[infer_request_id]
            request = self.exec_net.requests[infer_request_id]

            # Copy output prediction of the item this request handled before
            if out_id != -1:
                output[:,out_id:out_id+1] = request.output_blobs['output'].buffer

            # Start this request on new data
            infer_request_input_id[infer_request_id] = inp_id
            request.async_infer({'input': data[inp_id]})

        # Wait for the rest of requests
        status = self.exec_net.wait()
        if status != StatusCode.OK:
            raise Exception("Wait for idle request failed!")
        for infer_request_id, out_id in enumerate(infer_request_input_id):
            if out_id == -1:
                continue
            request = self.exec_net.requests[infer_request_id]
            output[:,out_id:out_id+1] = request.output_blobs['output'].buffer

        return torch.tensor(output)


class CTCModel(OpenVINOModel):
    """OpenVINO runner for models from the `bonito.ctc` package."""

    def __init__(self, model, dirname):
        """Delegate entirely to the base-class constructor."""
        super().__init__(model, dirname)

    def __call__(self, data):
        """
        Run inference on `data`, building the executable network on first use.

        A unit dimension is inserted at position 2 (1D->2D) before the data is
        handed to `process`.
        """
        data_2d = data.unsqueeze(2)  # 1D->2D
        network_ready = self.exec_net is not None
        if not network_ready:
            # The wrapped model is converted in place before it is exported.
            convert_to_2d(self.model)
            trace_shape = [1, 1, 1, data_2d.shape[-1]]
            self.init_model(self.model, trace_shape)

        return self.process(data_2d)

    def decode(self, x, beamsize=5, threshold=1e-3, qscores=False, return_path=False):
        """Forward decoding to the wrapped model with the same options."""
        options = dict(beamsize=beamsize, threshold=threshold,
                       qscores=qscores, return_path=return_path)
        return self.model.decode(x, **options)


class CRFModel(OpenVINOModel):
    """OpenVINO runner for models from the `bonito.crf` package."""

    def __init__(self, model, dirname):
        """Initialise the base wrapper and re-export the model's `seqdist`."""
        super().__init__(model, dirname)
        self.seqdist = model.seqdist

    def __call__(self, data):
        """
        Run inference on `data`, building the executable network on first use.

        Only `self.model.encoder` is exported to OpenVINO.
        """
        network_ready = self.exec_net is not None
        if not network_ready:
            trace_shape = [1, 1, data.shape[-1]]
            self.init_model(self.model.encoder, trace_shape)

        return self.process(data)

    def decode(self, x):
        """Forward decoding of a single item to the wrapped model."""
        decoded = self.model.decode(x)
        return decoded

    def decode_batch(self, x):
        """Forward batch decoding to the wrapped model."""
        decoded = self.model.decode_batch(x)
        return decoded
Loading

0 comments on commit 86b437d

Please sign in to comment.