Skip to content

Commit

Permalink
Revert "Revert "Support Transformers in the Wav2Vec2 Encoder for the ASR Inference (OpenNMT#1520)""
Browse files Browse the repository at this point in the history

This reverts commit 7c60769.
  • Loading branch information
Valentin Berkes committed Nov 7, 2023
1 parent 7c60769 commit 754e9dd
Show file tree
Hide file tree
Showing 16 changed files with 618 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ set(SOURCES
src/layers/common.cc
src/layers/decoder.cc
src/layers/transformer.cc
src/layers/wav2vec2.cc
src/layers/whisper.cc
src/logging.cc
src/models/language_model.cc
Expand All @@ -124,6 +125,7 @@ set(SOURCES
src/models/model_reader.cc
src/models/sequence_to_sequence.cc
src/models/transformer.cc
src/models/wav2vec2.cc
src/models/whisper.cc
src/ops/activation.cc
src/ops/add.cc
Expand Down
47 changes: 47 additions & 0 deletions include/ctranslate2/layers/wav2vec2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#pragma once

#include "ctranslate2/layers/transformer.h"

namespace ctranslate2 {
namespace layers {

// Wav2Vec2 speech encoder: the stack of Transformer encoder layers of the
// wav2vec2.encoder module followed by the final layer norm.
// The positional convolution (pos_conv_embed) is intentionally omitted
// because grouped convolutions (groups=16) are not supported.
class Wav2Vec2Encoder : public Layer {
public:
// Loads the encoder weights found in `model` under the variable prefix `scope`.
Wav2Vec2Encoder(const models::Model& model, const std::string& scope);

// Encodes `features` and writes the result to `output`.
void operator()(const StorageView& features, StorageView& output);

DataType output_type() const override {
return _output_norm.output_type();
}

dim_t output_size() const override {
return _output_norm.output_size();
}

// Expected size of the feature dimension of the input.
// NOTE(review): hard-coded to 1024 (wav2vec2-large hidden size) — confirm
// this is correct for other model sizes.
dim_t input_size() const {
return 1024;
}

// Returns true when `features` already looks like an encoder output rather
// than raw input features, so encoding can be skipped.
bool is_encoded(const StorageView& features) const {
// Input features shape: [batch_size, input_size, input_time]
// Encoder output shape: [batch_size, input_time // 2, output_size]
//
// input_time is variable so we check that dimension 1 is different from its original value.

return (features.rank() == 3
&& features.dim(2) == output_size()
&& features.dim(1) != input_size());
}

private:
const ops::GELU _gelu;
// wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
//const ops::Transpose _transpose;
const dim_t _num_heads;
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
const LayerNorm _output_norm;
};

}
}
72 changes: 72 additions & 0 deletions include/ctranslate2/models/wav2vec2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#pragma once

//#include "ctranslate2/generation.h"
#include "ctranslate2/layers/wav2vec2.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"

namespace ctranslate2 {
namespace models {

// Options controlling Wav2Vec2 generation.
// NOTE(review): these fields mirror the Whisper options and none of them is
// read by the encode-only API visible in this header — confirm they are
// actually used before relying on the defaults.
struct Wav2Vec2Options {
// Maximum generation length.
size_t max_length = 448;

// Randomly sample from the top K candidates (set 0 to sample from the full distribution).
size_t sampling_topk = 1;

// Maximum index of the first predicted timestamp.
size_t max_initial_timestamp_index = 50;

// Suppress blank outputs at the beginning of the sampling.
bool suppress_blank = true;

// List of token IDs to suppress.
// -1 will suppress a default set of symbols as defined in the model config.json file.
std::vector<int> suppress_tokens = {-1};
};


// Holds the Wav2Vec2 weights and vocabulary loaded from a converted model.
class Wav2Vec2Model : public Model {
public:
// Returns the vocabulary registered by the converter (used by the CTC head).
const Vocabulary& get_vocabulary() const;
size_t current_spec_revision() const override;
// Whether `variable_name` may be quantized.
bool is_quantizable(const std::string& variable_name) const override;
// Whether `variable_name` is a linear layer weight.
bool is_linear_weight(const std::string& variable_name) const override;
std::unique_ptr<Model> clone() const override;

bool use_global_int16_scale() const override {
return false;
}

protected:
// Finalizes the model once the weights have been read (loads the vocabulary).
void initialize(ModelReader& model_reader) override;
private:
std::shared_ptr<const Vocabulary> _vocabulary;
};

// A single model replica running the Wav2Vec2 encoder; instances are managed
// by the replica pool below.
class Wav2Vec2Replica : public ModelReplica {
public:
static std::unique_ptr<Wav2Vec2Replica> create_from_model(const Model& model);

Wav2Vec2Replica(const std::shared_ptr<const Wav2Vec2Model>& model);

// Runs the encoder on `features`; when `to_cpu` is true the output is
// copied back to the CPU before returning.
StorageView encode(StorageView features, const bool to_cpu);

private:
const std::shared_ptr<const Wav2Vec2Model> _model;
const std::unique_ptr<layers::Wav2Vec2Encoder> _encoder;

// Runs the encoder unless `features` appears to be an encoder output
// already — NOTE(review): presumably keyed on Wav2Vec2Encoder::is_encoded;
// confirm in the .cc implementation.
StorageView maybe_encode(StorageView features);
};

// Pool of Wav2Vec2 replicas exposing an asynchronous encode API.
class Wav2Vec2 : public ReplicaPool<Wav2Vec2Replica> {
public:
using ReplicaPool::ReplicaPool;

// Asynchronously encodes `features` on one of the pooled replicas.
std::future<StorageView> encode(const StorageView& features, const bool to_cpu);

};

}
}
1 change: 1 addition & 0 deletions python/cpp/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ PYBIND11_MODULE(_ext, m)
ctranslate2::python::register_generator(m);
ctranslate2::python::register_encoder(m);
ctranslate2::python::register_whisper(m);
ctranslate2::python::register_wav2vec2(m);
}
1 change: 1 addition & 0 deletions python/cpp/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace ctranslate2 {
void register_translation_stats(py::module& m);
void register_translator(py::module& m);
void register_whisper(py::module& m);
void register_wav2vec2(py::module& m);

}
}
93 changes: 93 additions & 0 deletions python/cpp/wav2vec2.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#include "module.h"

#include <ctranslate2/models/wav2vec2.h>

#include "replica_pool.h"

namespace ctranslate2 {
namespace python {

// Python-facing helper around the Wav2Vec2 replica pool.
class Wav2Vec2Wrapper : public ReplicaPoolHelper<models::Wav2Vec2> {
public:
  using ReplicaPoolHelper::ReplicaPoolHelper;

  // Submits the features to the pool and blocks until the encoder output
  // is available (the async future is resolved before returning to Python).
  StorageView encode(const StorageView& features, const bool to_cpu) {
    auto pending = _pool->encode(features, to_cpu);
    return pending.get();
  }
};


// Registers the Python "Wav2Vec2" class exposing the Wav2Vec2 encoder.
void register_wav2vec2(py::module& m) {
  py::class_<Wav2Vec2Wrapper>(
      m, "Wav2Vec2",
      R"pbdoc(
          Implements the Wav2Vec2 speech recognition model published by Facebook.

          See Also:
             https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec
      )pbdoc")

      .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
           py::arg("model_path"),
           py::arg("device")="cpu",
           py::kw_only(),
           py::arg("device_index")=0,
           py::arg("compute_type")="default",
           py::arg("inter_threads")=1,
           py::arg("intra_threads")=0,
           py::arg("max_queued_batches")=0,
           py::arg("files")=py::none(),
           R"pbdoc(
               Initializes a Wav2Vec2 model from a converted model.

               Arguments:
                 model_path: Path to the CTranslate2 model directory.
                 device: Device to use (possible values are: cpu, cuda, auto).
                 device_index: Device IDs where to place this model on.
                 compute_type: Model computation type or a dictionary mapping a device name
                   to the computation type (possible values are: default, auto, int8, int8_float32,
                   int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
                 inter_threads: Number of workers to allow executing multiple batches in parallel.
                 intra_threads: Number of OpenMP threads per worker (0 to use a default value).
                 max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
                   0 for an automatic value). When the queue is full, future requests will block
                   until a free slot is available.
                 files: Load model files from the memory. This argument is a dictionary mapping
                   file names to file contents as file-like or bytes objects. If this is set,
                   :obj:`model_path` acts as an identifier for this model.
           )pbdoc")

      .def_property_readonly("device", &Wav2Vec2Wrapper::device,
                             "Device this model is running on.")
      .def_property_readonly("device_index", &Wav2Vec2Wrapper::device_index,
                             "List of device IDs where this model is running on.")
      .def_property_readonly("compute_type", &Wav2Vec2Wrapper::compute_type,
                             "Computation type used by the model.")
      .def_property_readonly("num_workers", &Wav2Vec2Wrapper::num_replicas,
                             "Number of model workers backing this instance.")
      .def_property_readonly("num_queued_batches", &Wav2Vec2Wrapper::num_queued_batches,
                             "Number of batches waiting to be processed.")
      .def_property_readonly("num_active_batches", &Wav2Vec2Wrapper::num_active_batches,
                             "Number of batches waiting to be processed or currently processed.")

      // Fixed the encode() docstring: the previous text ("Mel spectogram ...
      // [batch_size, 80, 3000]") was copied from the Whisper binding and does
      // not describe Wav2Vec2 inputs; the encoder expects features shaped
      // [batch_size, input_size, input_time] (see layers/wav2vec2.h).
      .def("encode", &Wav2Vec2Wrapper::encode,
           py::arg("features"),
           py::arg("to_cpu")=false,
           py::call_guard<py::gil_scoped_release>(),
           R"pbdoc(
               Encodes the input features.

               Arguments:
                 features: Wav2Vec2 input features, as a float array with shape
                   ``[batch_size, input_size, input_time]``.
                 to_cpu: Copy the encoder output to the CPU before returning the value.

               Returns:
                 The encoder output.
           )pbdoc")

      ;
}

}
}
44 changes: 44 additions & 0 deletions python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
common_spec,
model_spec,
transformer_spec,
wav2vec2_spec,
whisper_spec,
)

Expand Down Expand Up @@ -937,6 +938,49 @@ def set_conv1d(self, spec, module):
spec.bias = module.bias


@register_loader("Wav2Vec2Config")
class Wav2Vec2Loader(BartLoader):
    """Converts a Hugging Face ``Wav2Vec2ForCTC`` model.

    Only the Transformer encoder layers and the CTC head are converted: the
    positional convolution (pos_conv_embed) uses grouped convolution with
    groups=16, which is not supported, so the convolutional front-end is left
    out of the spec.
    """

    @property
    def architecture_name(self):
        return "Wav2Vec2ForCTC"

    def get_model_spec(self, model):
        # Wav2Vec2 encoder Wav2Vec2PositionalConvEmbedding conv1d has groups 16
        # that doesn't look available here so we make Wav2Vec2 encoder layers only
        spec = wav2vec2_spec.Wav2Vec2Spec(
            model.wav2vec2.encoder.config.num_hidden_layers,
            model.wav2vec2.encoder.config.num_attention_heads,
        )

        # Alias the Wav2Vec2 submodules under the attribute names that
        # BartLoader.set_encoder expects (aliases only: no weight duplication).
        for layer in model.wav2vec2.encoder.layers:
            layer.self_attn = layer.attention
            layer.self_attn_layer_norm = layer.layer_norm
            layer.activation_fn = layer.feed_forward.intermediate_act_fn
            layer.fc1 = layer.feed_forward.intermediate_dense
            layer.fc2 = layer.feed_forward.output_dense

        self.set_encoder(spec.encoder, model.wav2vec2.encoder)
        # only for Wav2Vec2Spec.get_vocabulary_size()
        self.set_linear(spec.lm_head, model.lm_head)
        return spec

    def set_config(self, config, model, tokenizer):
        # Wav2Vec2Config is empty: there is no generation configuration to save.
        return

    def get_vocabulary(self, model, tokenizer):
        # NOTE(review): get_vocab() returns a token -> id dict; confirm the
        # conversion pipeline handles the ordering correctly.
        return tokenizer.get_vocab()

    def set_vocabulary(self, spec, tokens):
        spec.register_vocabulary(tokens)

    # Removed the redundant set_encoder override that only delegated to
    # BartLoader.set_encoder with identical arguments (dead code).

    def set_common_layers(self, spec, module):
        # The encoder spec only stores the final layer norm (no positional
        # embedding weights, see Wav2Vec2EncoderSpec).
        self.set_layer_norm(spec.layer_norm, module.layer_norm)


@register_loader("T5Config")
class T5Loader(ModelLoader):
@property
Expand Down
1 change: 1 addition & 0 deletions python/ctranslate2/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

try:
from ctranslate2._ext import (
Wav2Vec2,
Whisper,
WhisperGenerationResult,
WhisperGenerationResultAsync,
Expand Down
1 change: 1 addition & 0 deletions python/ctranslate2/specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
TransformerEncoderSpec,
TransformerSpec,
)
from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec
from ctranslate2.specs.whisper_spec import WhisperSpec
43 changes: 43 additions & 0 deletions python/ctranslate2/specs/wav2vec2_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import List, Optional, Tuple

import numpy as np

from ctranslate2.specs import common_spec, model_spec, transformer_spec


class Wav2Vec2Config(model_spec.ModelConfig):
    """Configuration for the Wav2Vec2 model.

    Wav2Vec2 currently has no configuration values to save, so this config is
    intentionally empty.
    """

    def __init__(self):
        # Intentionally empty (replaced the unidiomatic bare ``return`` body).
        # NOTE(review): super().__init__() is not called, so no default
        # ModelConfig entries are registered — confirm this is deliberate.
        pass


class Wav2Vec2Spec(model_spec.LanguageModelSpec):
    """Model specification for Wav2Vec2: Transformer encoder layers + CTC head.

    NOTE(review): inherits LanguageModelSpec although only an encoder and a
    linear head are declared — confirm this base class is intended.
    """

    def __init__(self, num_layers, num_heads):
        # num_layers: number of Transformer encoder layers to declare.
        # num_heads: number of attention heads per layer.
        super().__init__()
        self.encoder = Wav2Vec2EncoderSpec(num_layers, num_heads)
        self.lm_head = common_spec.LinearSpec()

    @property
    def name(self):
        # Spec name matched by the loader on the C++ side.
        return "Wav2Vec2Spec"

    @property
    def revision(self):
        return 3

    def get_default_config(self):
        return Wav2Vec2Config()

    def get_vocabulary_size(self):
        # The CTC head projects to the vocabulary, so its output dimension
        # (rows of the weight matrix) is the vocabulary size.
        return self.lm_head.weight.shape[0]


class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
    """Specification of the Wav2Vec2 encoder (Transformer layers only)."""

    def __init__(self, num_layers, num_heads):
        # Stored as a NumPy int16 scalar so the value is serialized with the
        # model weights.
        self.num_heads = np.dtype("int16").type(num_heads)
        # wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
        self.layer_norm = common_spec.LayerNormSpec()
        self.layer = [
            transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
        ]
1 change: 1 addition & 0 deletions python/tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ OpenNMT-tf==2.30.*
tensorflow-cpu==2.11.*
pytest
wurlitzer==3.0.*;platform_system=='Linux'
torch
Loading

0 comments on commit 754e9dd

Please sign in to comment.