forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "Revert "Support Transformers in the Wav2Vec2 Encoder for the ASR Inference (OpenNMT#1520)""
This reverts commit 7c60769.
- Loading branch information
Valentin Berkes
committed
Nov 7, 2023
1 parent
7c60769
commit 754e9dd
Showing
16 changed files
with
618 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#pragma once | ||
|
||
#include "ctranslate2/layers/transformer.h" | ||
|
||
namespace ctranslate2 { | ||
namespace layers { | ||
|
||
class Wav2Vec2Encoder : public Layer { | ||
public: | ||
Wav2Vec2Encoder(const models::Model& model, const std::string& scope); | ||
|
||
void operator()(const StorageView& features, StorageView& output); | ||
|
||
DataType output_type() const override { | ||
return _output_norm.output_type(); | ||
} | ||
|
||
dim_t output_size() const override { | ||
return _output_norm.output_size(); | ||
} | ||
|
||
dim_t input_size() const { | ||
return 1024; | ||
} | ||
|
||
bool is_encoded(const StorageView& features) const { | ||
// Input features shape: [batch_size, input_size, input_time] | ||
// Encoder output shape: [batch_size, input_time // 2, output_size] | ||
// | ||
// input_time is variable so we check that dimension 1 is different than its original value. | ||
|
||
return (features.rank() == 3 | ||
&& features.dim(2) == output_size() | ||
&& features.dim(1) != input_size()); | ||
} | ||
|
||
private: | ||
const ops::GELU _gelu; | ||
// wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported | ||
//const ops::Transpose _transpose; | ||
const dim_t _num_heads; | ||
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers; | ||
const LayerNorm _output_norm; | ||
}; | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#pragma once | ||
|
||
//#include "ctranslate2/generation.h" | ||
#include "ctranslate2/layers/wav2vec2.h" | ||
#include "ctranslate2/models/model.h" | ||
#include "ctranslate2/replica_pool.h" | ||
|
||
namespace ctranslate2 { | ||
namespace models { | ||
|
||
// Option bundle for Wav2Vec2, with a default value for every field.
//
// NOTE(review): no declaration visible in this header takes a Wav2Vec2Options
// (encode() only receives the features and a `to_cpu` flag) -- confirm whether
// these fields are consumed anywhere.
struct Wav2Vec2Options {
  // Maximum generation length.
  size_t max_length = 448;

  // Randomly sample from the top K candidates (set 0 to sample from the full distribution).
  size_t sampling_topk = 1;

  // Maximum index of the first predicted timestamp.
  size_t max_initial_timestamp_index = 50;

  // Suppress blank outputs at the beginning of the sampling.
  bool suppress_blank = true;

  // List of token IDs to suppress.
  // -1 will suppress a default set of symbols as defined in the model config.json file.
  std::vector<int> suppress_tokens = {-1};
};
|
||
|
||
class Wav2Vec2Model : public Model { | ||
public: | ||
const Vocabulary& get_vocabulary() const; | ||
size_t current_spec_revision() const override; | ||
bool is_quantizable(const std::string& variable_name) const override; | ||
bool is_linear_weight(const std::string& variable_name) const override; | ||
std::unique_ptr<Model> clone() const override; | ||
|
||
bool use_global_int16_scale() const override { | ||
return false; | ||
} | ||
|
||
protected: | ||
void initialize(ModelReader& model_reader) override; | ||
private: | ||
std::shared_ptr<const Vocabulary> _vocabulary; | ||
}; | ||
|
||
class Wav2Vec2Replica : public ModelReplica { | ||
public: | ||
static std::unique_ptr<Wav2Vec2Replica> create_from_model(const Model& model); | ||
|
||
Wav2Vec2Replica(const std::shared_ptr<const Wav2Vec2Model>& model); | ||
|
||
StorageView encode(StorageView features, const bool to_cpu); | ||
|
||
private: | ||
const std::shared_ptr<const Wav2Vec2Model> _model; | ||
const std::unique_ptr<layers::Wav2Vec2Encoder> _encoder; | ||
|
||
StorageView maybe_encode(StorageView features); | ||
}; | ||
|
||
// Pool of Wav2Vec2Replica workers. Constructors are inherited from ReplicaPool.
class Wav2Vec2 : public ReplicaPool<Wav2Vec2Replica> {
public:
  using ReplicaPool::ReplicaPool;

  // Asynchronous encode: the result is delivered through the returned future.
  std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
};
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
#include "module.h" | ||
|
||
#include <ctranslate2/models/wav2vec2.h> | ||
|
||
#include "replica_pool.h" | ||
|
||
namespace ctranslate2 { | ||
namespace python { | ||
|
||
// Python-facing wrapper around a models::Wav2Vec2 pool. Constructors are
// inherited from ReplicaPoolHelper.
class Wav2Vec2Wrapper : public ReplicaPoolHelper<models::Wav2Vec2> {
public:
  using ReplicaPoolHelper::ReplicaPoolHelper;

  // Synchronous encode: submits the request to the pool and blocks on the
  // returned future until the result is available.
  StorageView encode(const StorageView& features, const bool to_cpu) {
    return _pool->encode(features, to_cpu).get();
  }
};
|
||
|
||
// Registers the Python class "Wav2Vec2" on module `m`, backed by
// Wav2Vec2Wrapper: one keyword-friendly constructor, read-only properties
// describing the pool, and the `encode` method.
void register_wav2vec2(py::module& m) {
  py::class_<Wav2Vec2Wrapper>(
    m, "Wav2Vec2",
    R"pbdoc(
        Implements the Wav2Vec2 speech recognition model published by Facebook.
        See Also:
           https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec
    )pbdoc")

    // Everything after `device` is keyword-only on the Python side.
    .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
         py::arg("model_path"),
         py::arg("device")="cpu",
         py::kw_only(),
         py::arg("device_index")=0,
         py::arg("compute_type")="default",
         py::arg("inter_threads")=1,
         py::arg("intra_threads")=0,
         py::arg("max_queued_batches")=0,
         py::arg("files")=py::none(),
         R"pbdoc(
             Initializes a Wav2Vec2 model from a converted model.
             Arguments:
               model_path: Path to the CTranslate2 model directory.
               device: Device to use (possible values are: cpu, cuda, auto).
               device_index: Device IDs where to place this model on.
               compute_type: Model computation type or a dictionary mapping a device name
                 to the computation type (possible values are: default, auto, int8, int8_float32,
                 int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
               inter_threads: Number of workers to allow executing multiple batches in parallel.
               intra_threads: Number of OpenMP threads per worker (0 to use a default value).
               max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
                 0 for an automatic value). When the queue is full, future requests will block
                 until a free slot is available.
               files: Load model files from the memory. This argument is a dictionary mapping
                 file names to file contents as file-like or bytes objects. If this is set,
                 :obj:`model_path` acts as an identifier for this model.
         )pbdoc")

    .def_property_readonly("device", &Wav2Vec2Wrapper::device,
                           "Device this model is running on.")
    .def_property_readonly("device_index", &Wav2Vec2Wrapper::device_index,
                           "List of device IDs where this model is running on.")
    .def_property_readonly("compute_type", &Wav2Vec2Wrapper::compute_type,
                           "Computation type used by the model.")
    .def_property_readonly("num_workers", &Wav2Vec2Wrapper::num_replicas,
                           "Number of model workers backing this instance.")
    .def_property_readonly("num_queued_batches", &Wav2Vec2Wrapper::num_queued_batches,
                           "Number of batches waiting to be processed.")
    .def_property_readonly("num_active_batches", &Wav2Vec2Wrapper::num_active_batches,
                           "Number of batches waiting to be processed or currently processed.")

    // The GIL is released for the duration of the call because encode() blocks
    // on a future.
    // NOTE(review): the documented input ("Mel spectrogram", shape
    // [batch_size, 80, 3000]) disagrees with Wav2Vec2Encoder::input_size(),
    // which returns 1024 -- confirm the real input layout and update the text.
    .def("encode", &Wav2Vec2Wrapper::encode,
         py::arg("features"),
         py::arg("to_cpu")=false,
         py::call_guard<py::gil_scoped_release>(),
         R"pbdoc(
             Encodes the input features.
             Arguments:
               features: Mel spectrogram of the audio, as a float array with shape
                 ``[batch_size, 80, 3000]``.
               to_cpu: Copy the encoder output to the CPU before returning the value.
             Returns:
               The encoder output.
         )pbdoc")

    ;
}
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import List, Optional, Tuple | ||
|
||
import numpy as np | ||
|
||
from ctranslate2.specs import common_spec, model_spec, transformer_spec | ||
|
||
|
||
class Wav2Vec2Config(model_spec.ModelConfig):
    """Configuration for the Wav2Vec2 model.

    Defines no configuration fields of its own.
    """

    def __init__(self):
        # Run the base-class initializer rather than returning early, so any
        # state it sets up is present on the instance.
        super().__init__()
|
||
|
||
class Wav2Vec2Spec(model_spec.LanguageModelSpec):
    """Model specification made of a Wav2Vec2 encoder and a linear ``lm_head``."""

    def __init__(self, num_layers, num_heads):
        """Builds the specification.

        Arguments:
          num_layers: Number of encoder layers, forwarded to the encoder spec.
          num_heads: Number of attention heads, forwarded to the encoder spec.
        """
        super().__init__()
        self.encoder = Wav2Vec2EncoderSpec(num_layers, num_heads)
        self.lm_head = common_spec.LinearSpec()

    @property
    def name(self):
        """Name of this specification."""
        return "Wav2Vec2Spec"

    @property
    def revision(self):
        """Revision number of this specification."""
        return 3

    def get_default_config(self):
        """Returns a fresh :class:`Wav2Vec2Config`."""
        return Wav2Vec2Config()

    def get_vocabulary_size(self):
        """Returns the first dimension of the ``lm_head`` weight."""
        return self.lm_head.weight.shape[0]
|
||
|
||
class Wav2Vec2EncoderSpec(model_spec.LayerSpec):
    """Specification of the encoder: a layer norm plus a list of Transformer encoder layers."""

    def __init__(self, num_layers, num_heads):
        """Builds the encoder specification.

        Arguments:
          num_layers: Number of Transformer encoder layer specs to create.
          num_heads: Number of attention heads, stored as a NumPy int16 scalar.
        """
        self.num_heads = np.dtype("int16").type(num_heads)
        # wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported
        self.layer_norm = common_spec.LayerNormSpec()
        self.layer = [
            transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers)
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ OpenNMT-tf==2.30.* | |
tensorflow-cpu==2.11.* | ||
pytest | ||
wurlitzer==3.0.*;platform_system=='Linux' | ||
torch |
Oops, something went wrong.