diff --git a/CMakeLists.txt b/CMakeLists.txt index 79717045b..ce8b3d31f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -116,6 +116,7 @@ set(SOURCES src/layers/common.cc src/layers/decoder.cc src/layers/transformer.cc + src/layers/wav2vec2.cc src/layers/whisper.cc src/logging.cc src/models/language_model.cc @@ -124,6 +125,7 @@ set(SOURCES src/models/model_reader.cc src/models/sequence_to_sequence.cc src/models/transformer.cc + src/models/wav2vec2.cc src/models/whisper.cc src/ops/activation.cc src/ops/add.cc diff --git a/include/ctranslate2/layers/wav2vec2.h b/include/ctranslate2/layers/wav2vec2.h new file mode 100644 index 000000000..4c25c941a --- /dev/null +++ b/include/ctranslate2/layers/wav2vec2.h @@ -0,0 +1,47 @@ +#pragma once + +#include "ctranslate2/layers/transformer.h" + +namespace ctranslate2 { + namespace layers { + + class Wav2Vec2Encoder : public Layer { + public: + Wav2Vec2Encoder(const models::Model& model, const std::string& scope); + + void operator()(const StorageView& features, StorageView& output); + + DataType output_type() const override { + return _output_norm.output_type(); + } + + dim_t output_size() const override { + return _output_norm.output_size(); + } + + dim_t input_size() const { + return 1024; + } + + bool is_encoded(const StorageView& features) const { + // Input features shape: [batch_size, input_size, input_time] + // Encoder output shape: [batch_size, input_time // 2, output_size] + // + // input_time is variable so we check that dimension 1 is different than its original value. + + return (features.rank() == 3 + && features.dim(2) == output_size() + && features.dim(1) != input_size()); + } + + private: + const ops::GELU _gelu; + // wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported + //const ops::Transpose _transpose; + const dim_t _num_heads; + const std::vector> _layers; + const LayerNorm _output_norm; + }; + + } +} diff --git a/include/ctranslate2/models/wav2vec2.h b/include/ctranslate2/models/wav2vec2.h new file mode 100644 index 000000000..d1034ef88 --- /dev/null +++ b/include/ctranslate2/models/wav2vec2.h @@ -0,0 +1,72 @@ +#pragma once + +//#include "ctranslate2/generation.h" +#include "ctranslate2/layers/wav2vec2.h" +#include "ctranslate2/models/model.h" +#include "ctranslate2/replica_pool.h" + +namespace ctranslate2 { + namespace models { + + struct Wav2Vec2Options { + // Maximum generation length. + size_t max_length = 448; + + // Randomly sample from the top K candidates (set 0 to sample from the full distribution). + size_t sampling_topk = 1; + + // Maximum index of the first predicted timestamp. + size_t max_initial_timestamp_index = 50; + + // Suppress blank outputs at the beginning of the sampling. + bool suppress_blank = true; + + // List of token IDs to suppress. + // -1 will suppress a default set of symbols as defined in the model config.json file. 
+ std::vector suppress_tokens = {-1}; + }; + + + class Wav2Vec2Model : public Model { + public: + const Vocabulary& get_vocabulary() const; + size_t current_spec_revision() const override; + bool is_quantizable(const std::string& variable_name) const override; + bool is_linear_weight(const std::string& variable_name) const override; + std::unique_ptr clone() const override; + + bool use_global_int16_scale() const override { + return false; + } + + protected: + void initialize(ModelReader& model_reader) override; + private: + std::shared_ptr _vocabulary; + }; + + class Wav2Vec2Replica : public ModelReplica { + public: + static std::unique_ptr create_from_model(const Model& model); + + Wav2Vec2Replica(const std::shared_ptr& model); + + StorageView encode(StorageView features, const bool to_cpu); + + private: + const std::shared_ptr _model; + const std::unique_ptr _encoder; + + StorageView maybe_encode(StorageView features); + }; + + class Wav2Vec2 : public ReplicaPool { + public: + using ReplicaPool::ReplicaPool; + + std::future encode(const StorageView& features, const bool to_cpu); + + }; + + } +} diff --git a/python/cpp/module.cc b/python/cpp/module.cc index 997414989..4a9e47561 100644 --- a/python/cpp/module.cc +++ b/python/cpp/module.cc @@ -86,4 +86,5 @@ PYBIND11_MODULE(_ext, m) ctranslate2::python::register_generator(m); ctranslate2::python::register_encoder(m); ctranslate2::python::register_whisper(m); + ctranslate2::python::register_wav2vec2(m); } diff --git a/python/cpp/module.h b/python/cpp/module.h index b314969c4..01fdbdf59 100644 --- a/python/cpp/module.h +++ b/python/cpp/module.h @@ -17,6 +17,7 @@ namespace ctranslate2 { void register_translation_stats(py::module& m); void register_translator(py::module& m); void register_whisper(py::module& m); + void register_wav2vec2(py::module& m); } } diff --git a/python/cpp/wav2vec2.cc b/python/cpp/wav2vec2.cc new file mode 100644 index 000000000..ced116cb4 --- /dev/null +++ b/python/cpp/wav2vec2.cc @@ -0,0 +1,93 @@ +#include "module.h" + +#include + +#include "replica_pool.h" + +namespace ctranslate2 { + namespace python { + + class Wav2Vec2Wrapper : public ReplicaPoolHelper { + public: + using ReplicaPoolHelper::ReplicaPoolHelper; + + StorageView encode(const StorageView& features, const bool to_cpu) { + return _pool->encode(features, to_cpu).get(); + } + }; + + + void register_wav2vec2(py::module& m) { + py::class_( + m, "Wav2Vec2", + R"pbdoc( + Implements the Wav2Vec2 speech recognition model published by Facebook. + + See Also: + https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec + )pbdoc") + + .def(py::init>&, const StringOrMap&, size_t, size_t, long, py::object>(), + py::arg("model_path"), + py::arg("device")="cpu", + py::kw_only(), + py::arg("device_index")=0, + py::arg("compute_type")="default", + py::arg("inter_threads")=1, + py::arg("intra_threads")=0, + py::arg("max_queued_batches")=0, + py::arg("files")=py::none(), + R"pbdoc( + Initializes a Wav2Vec2 model from a converted model. + + Arguments: + model_path: Path to the CTranslate2 model directory. + device: Device to use (possible values are: cpu, cuda, auto). + device_index: Device IDs where to place this model on. + compute_type: Model computation type or a dictionary mapping a device name + to the computation type (possible values are: default, auto, int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + inter_threads: Number of workers to allow executing multiple batches in parallel. 
+ intra_threads: Number of OpenMP threads per worker (0 to use a default value). + max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited, + 0 for an automatic value). When the queue is full, future requests will block + until a free slot is available. + files: Load model files from the memory. This argument is a dictionary mapping + file names to file contents as file-like or bytes objects. If this is set, + :obj:`model_path` acts as an identifier for this model. + )pbdoc") + + .def_property_readonly("device", &Wav2Vec2Wrapper::device, + "Device this model is running on.") + .def_property_readonly("device_index", &Wav2Vec2Wrapper::device_index, + "List of device IDs where this model is running on.") + .def_property_readonly("compute_type", &Wav2Vec2Wrapper::compute_type, + "Computation type used by the model.") + .def_property_readonly("num_workers", &Wav2Vec2Wrapper::num_replicas, + "Number of model workers backing this instance.") + .def_property_readonly("num_queued_batches", &Wav2Vec2Wrapper::num_queued_batches, + "Number of batches waiting to be processed.") + .def_property_readonly("num_active_batches", &Wav2Vec2Wrapper::num_active_batches, + "Number of batches waiting to be processed or currently processed.") + + .def("encode", &Wav2Vec2Wrapper::encode, + py::arg("features"), + py::arg("to_cpu")=false, + py::call_guard(), + R"pbdoc( + Encodes the input features. + + Arguments: + features: Mel spectogram of the audio, as a float array with shape + ``[batch_size, 80, 3000]``. + to_cpu: Copy the encoder output to the CPU before returning the value. + + Returns: + The encoder output. + )pbdoc") + + ; + } + + } +} diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py index af5888648..a445fe9c1 100644 --- a/python/ctranslate2/converters/transformers.py +++ b/python/ctranslate2/converters/transformers.py @@ -22,6 +22,7 @@ common_spec, model_spec, transformer_spec, + wav2vec2_spec, whisper_spec, ) @@ -937,6 +938,49 @@ def set_conv1d(self, spec, module): spec.bias = module.bias +@register_loader("Wav2Vec2Config") +class Wav2Vec2Loader(BartLoader): + @property + def architecture_name(self): + return "Wav2Vec2ForCTC" + + def get_model_spec(self, model): + # Wav2Vec2 encoder Wav2Vec2PositionalConvEmbedding conv1d has groups 16 + # that doesn't look available here so we make Wav2Vec2 encoder layers only + spec = wav2vec2_spec.Wav2Vec2Spec( + model.wav2vec2.encoder.config.num_hidden_layers, + model.wav2vec2.encoder.config.num_attention_heads, + ) + + # layer component name matching (no duplications saving) + for layer in model.wav2vec2.encoder.layers: + layer.self_attn = layer.attention + layer.self_attn_layer_norm = layer.layer_norm + layer.activation_fn = layer.feed_forward.intermediate_act_fn + layer.fc1 = layer.feed_forward.intermediate_dense + layer.fc2 = layer.feed_forward.output_dense + + self.set_encoder(spec.encoder, model.wav2vec2.encoder) + self.set_linear(spec.lm_head, model.lm_head) + # only for Wav2Vec2Spec.get_vocabulary_size() + return spec + + def set_config(self, config, model, tokenizer): + return + + def get_vocabulary(self, model, tokenizer): + return tokenizer.get_vocab() + + def set_vocabulary(self, spec, tokens): + spec.register_vocabulary(tokens) + + def set_encoder(self, spec, encoder): + super().set_encoder(spec, encoder) + + def set_common_layers(self, spec, module): + self.set_layer_norm(spec.layer_norm, module.layer_norm) + + @register_loader("T5Config") class T5Loader(ModelLoader): 
@property diff --git a/python/ctranslate2/models/__init__.py b/python/ctranslate2/models/__init__.py index 067a32d8c..aba612a5c 100644 --- a/python/ctranslate2/models/__init__.py +++ b/python/ctranslate2/models/__init__.py @@ -4,6 +4,7 @@ try: from ctranslate2._ext import ( + Wav2Vec2, Whisper, WhisperGenerationResult, WhisperGenerationResultAsync, diff --git a/python/ctranslate2/specs/__init__.py b/python/ctranslate2/specs/__init__.py index 4a2bf41a1..22552f5c9 100644 --- a/python/ctranslate2/specs/__init__.py +++ b/python/ctranslate2/specs/__init__.py @@ -13,4 +13,5 @@ TransformerEncoderSpec, TransformerSpec, ) +from ctranslate2.specs.wav2vec2_spec import Wav2Vec2Spec from ctranslate2.specs.whisper_spec import WhisperSpec diff --git a/python/ctranslate2/specs/wav2vec2_spec.py b/python/ctranslate2/specs/wav2vec2_spec.py new file mode 100644 index 000000000..78b2ffa84 --- /dev/null +++ b/python/ctranslate2/specs/wav2vec2_spec.py @@ -0,0 +1,43 @@ +from typing import List, Optional, Tuple + +import numpy as np + +from ctranslate2.specs import common_spec, model_spec, transformer_spec + + +class Wav2Vec2Config(model_spec.ModelConfig): + """Configuration for the Wav2Vec2 model.""" + + def __init__(self): + return + + +class Wav2Vec2Spec(model_spec.LanguageModelSpec): + def __init__(self, num_layers, num_heads): + super().__init__() + self.encoder = Wav2Vec2EncoderSpec(num_layers, num_heads) + self.lm_head = common_spec.LinearSpec() + + @property + def name(self): + return "Wav2Vec2Spec" + + @property + def revision(self): + return 3 + + def get_default_config(self): + return Wav2Vec2Config() + + def get_vocabulary_size(self): + return self.lm_head.weight.shape[0] + + +class Wav2Vec2EncoderSpec(model_spec.LayerSpec): + def __init__(self, num_layers, num_heads): + self.num_heads = np.dtype("int16").type(num_heads) + # wav2vec2.encoder modules except pos_conv_embed due to groups=16 being not supported + self.layer_norm = common_spec.LayerNormSpec() + self.layer = [ + transformer_spec.TransformerEncoderLayerSpec() for _ in range(num_layers) + ] diff --git a/python/tests/requirements.txt b/python/tests/requirements.txt index 71c3382a6..f9cc04edf 100644 --- a/python/tests/requirements.txt +++ b/python/tests/requirements.txt @@ -5,3 +5,4 @@ OpenNMT-tf==2.30.* tensorflow-cpu==2.11.* pytest wurlitzer==3.0.*;platform_system=='Linux' +torch diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py index a34b752a8..d85299838 100644 --- a/python/tests/test_transformers.py +++ b/python/tests/test_transformers.py @@ -943,3 +943,134 @@ def test_transformers_whisper_include_tokenizer_json(self, tmp_dir): output_dir = str(tmp_dir.join("ctranslate2_model")) output_dir = converter.convert(output_dir) assert os.path.isfile(os.path.join(output_dir, "tokenizer.json")) + + +class TestWav2Vec2: + @classmethod + def teardown_class(cls): + clear_transformers_cache_in_ci() + + @test_utils.only_on_linux + @test_utils.on_available_devices + @pytest.mark.parametrize( + "model_name,expected_transcription", + [ + ( + "facebook/wav2vec2-large-robust-ft-swbd-300h", + [ + "MISTER QUILTER IS THE APOSSEL OF THE MIDDLE CLASSES AND" + " WE ARE GLAD TO WELCOME HIS GOSPEL", + ], + ), + ], + ) + def test_transformers_wav2vec2( + self, + tmp_dir, + device, + model_name, + expected_transcription, + ): + import torch + import transformers + + converter = ctranslate2.converters.TransformersConverter( + model_name, load_as_float16="int8" + ) + output_dir = str(tmp_dir.join("ctranslate2_model")) + output_dir = 
converter.convert(output_dir) + # 24 x Wav2Vec2EncoderLayerStableLayerNorm converted & saved + + w2v2_model = transformers.Wav2Vec2ForCTC.from_pretrained(model_name) + del w2v2_model.wav2vec2.encoder.layers + del w2v2_model.wav2vec2.encoder.layer_norm + torch.save(w2v2_model, output_dir + "/wav2vec2_partial.bin") + w2v2_processor = transformers.Wav2Vec2Processor.from_pretrained(model_name) + torch.save(w2v2_processor, output_dir + "/wav2vec2_processor.bin") + + device = "cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu" + cpu_threads = int(os.environ.get("OMP_NUM_THREADS", 0)) + w2v2_model = torch.load(output_dir + "/wav2vec2_partial.bin").to(device) + w2v2_processor = torch.load(output_dir + "/wav2vec2_processor.bin") + ct2_w2v2_model = ctranslate2.models.Wav2Vec2( + output_dir, + device=device, + device_index=[0], + compute_type="int8", + intra_threads=cpu_threads, + inter_threads=1, + ) + + speech_array = np.load( + os.path.join(test_utils.get_data_dir(), "audio", "mr_quilter.npy") + ) + input_values = w2v2_processor( + speech_array, + padding=True, + return_tensors="pt", + sampling_rate=16000, + ).input_values + + with torch.no_grad(): + extract_features = w2v2_model.wav2vec2.feature_extractor( + input_values.to(w2v2_model.device) + ).transpose(1, 2) + hidden_states, extract_features = w2v2_model.wav2vec2.feature_projection( + extract_features + ) + position_embeddings = w2v2_model.wav2vec2.encoder.pos_conv_embed( + hidden_states + ) + hidden_states = position_embeddings + hidden_states + # hidden_states = w2v2_model.encoder.dropout(hidden_states) + # Dropout(p=0.0, inplace=False) bypassed + + if ct2_w2v2_model.device == "cuda": + hidden_states = hidden_states.cpu() + else: + hidden_states.numpy() + + hidden_states = np.ascontiguousarray(hidden_states) + hidden_states = ctranslate2.StorageView.from_array(hidden_states) + to_cpu = ( + ct2_w2v2_model.device == "cuda" and len(ct2_w2v2_model.device_index) > 1 + ) + ct2_output = ct2_w2v2_model.encode( + hidden_states, + to_cpu=to_cpu, + ) # 24 x Wav2Vec2EncoderLayerStableLayerNorm processed + if ct2_w2v2_model.device == "cuda": + hidden_states = torch.as_tensor( + ct2_output, + device=ct2_w2v2_model.device, + ) + else: + hidden_states = torch.as_tensor( + np.array(ct2_output), + dtype=torch.float32, + device=ct2_w2v2_model.device, + ) + + encoder_outputs = transformers.modeling_outputs.BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=None, + attentions=None, + ) + hidden_states = encoder_outputs[0] + outputs = transformers.modeling_outputs.Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + hidden_states = outputs[0] + # hidden_states = w2v2_model.dropout(hidden_states) + # Dropout(p=0.0, inplace=False) bypassed + + with torch.no_grad(): + logits = w2v2_model.lm_head(hidden_states.to(torch.float32))[0] + + predicted_ids = torch.argmax(logits, dim=-1) + transcription = w2v2_processor.decode(predicted_ids, output_word_offsets=True) + + assert transcription[0] == expected_transcription[0] diff --git a/python/tools/prepare_build_environment_linux.sh b/python/tools/prepare_build_environment_linux.sh index f1416295b..a4009e566 100755 --- a/python/tools/prepare_build_environment_linux.sh +++ b/python/tools/prepare_build_environment_linux.sh @@ -32,7 +32,7 @@ else libcublas-devel-11-2-11.4.1.1043-1 ln -s cuda-11.2 /usr/local/cuda - ONEAPI_VERSION=2023.0.0 + ONEAPI_VERSION=2023.2.0 
yum-config-manager --add-repo https://yum.repos.intel.com/oneapi
 rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB
 yum install -y intel-oneapi-mkl-devel-$ONEAPI_VERSION
diff --git a/src/layers/wav2vec2.cc b/src/layers/wav2vec2.cc
new file mode 100644
index 000000000..237c77fad
--- /dev/null
+++ b/src/layers/wav2vec2.cc
@@ -0,0 +1,58 @@
+#include "ctranslate2/layers/wav2vec2.h"
+
+
+namespace ctranslate2 {
+  namespace layers {
+    Wav2Vec2Encoder::Wav2Vec2Encoder(const models::Model& model, const std::string& scope)
+      : _num_heads(model.get_attribute_with_default<int32_t>(scope + "/num_heads", 8))
+      , _layers(build_layers_list<const TransformerEncoderLayer>(model,
+                                                                 scope + "/layer",
+                                                                 _num_heads,
+                                                                 /*pre_norm=*/true,
+                                                                 ops::ActivationType::GELU))
+      , _output_norm(model, scope + "/layer_norm")
+    {
+    }
+
+    void Wav2Vec2Encoder::operator()(const StorageView& features, StorageView& output) {
+      PROFILE("Wav2Vec2Encoder");
+
+      // SAD in the front-end handles the input length.
+      //const dim_t expected_depth = 1024;
+      //const dim_t expected_time = 406;
+
+      if (features.rank() != 3)
+        throw std::invalid_argument("Expected input features to have 3 dimensions, but got "
+                                    + std::to_string(features.rank())
+                                    + " dimension(s) instead");
+      /* // may need to limit the input length
+      if (features.dim(1) != expected_depth || features.dim(2) != expected_time)
+        throw std::invalid_argument("Invalid input features shape: expected an input with shape ("
+                                    + std::to_string(features.dim(0))
+                                    + ", "
+                                    + std::to_string(expected_depth)
+                                    + ", "
+                                    + std::to_string(expected_time)
+                                    + "), but got an input with shape ("
+                                    + std::to_string(features.dim(0))
+                                    + ", "
+                                    + std::to_string(features.dim(1))
+                                    + ", "
+                                    + std::to_string(features.dim(2))
+                                    + ") instead;; _conv1.output_size() "
+                                    + std::to_string(_conv1.output_size()));
+                                    //+ ") instead");
+      */
+
+      StorageView input(output_type(), features.device());
+      input = features;
+      for (const auto& layer : _layers) {
+        (*layer)(input, nullptr, output);
+        input = std::move(output);
+      }
+
+      _output_norm(input, output);
+    }
+
+  }
+}
diff --git a/src/models/model_factory.cc b/src/models/model_factory.cc
index e5a904aff..488e0b8b2 100644
--- a/src/models/model_factory.cc
+++ b/src/models/model_factory.cc
@@ -3,6 +3,7 @@
 #include
 
 #include "ctranslate2/models/whisper.h"
+#include "ctranslate2/models/wav2vec2.h"
 #include "ctranslate2/models/transformer.h"
 
 namespace ctranslate2 {
@@ -20,6 +21,8 @@ namespace ctranslate2 {
       register_model<TransformerEncoderModel>("TransformerEncoderSpec");
 
       register_model<WhisperModel>("WhisperSpec");
+
+      register_model<Wav2Vec2Model>("Wav2Vec2Spec");
     }
 
     std::shared_ptr create_model(const std::string& name) {
diff --git a/src/models/wav2vec2.cc b/src/models/wav2vec2.cc
new file mode 100644
index 000000000..79a7a40d4
--- /dev/null
+++ b/src/models/wav2vec2.cc
@@ -0,0 +1,119 @@
+#include "ctranslate2/models/wav2vec2.h"
+
+#include <algorithm>
+
+#include "ctranslate2/decoding.h"
+
+#include "dispatch.h"
+#include "dtw.h"
+
+#ifdef CT2_WITH_CUDA
+#  include "cuda/utils.h"
+#endif
+
+
+namespace ctranslate2 {
+  namespace models {
+
+    const Vocabulary& Wav2Vec2Model::get_vocabulary() const {
+      return *_vocabulary;
+    }
+
+    size_t Wav2Vec2Model::current_spec_revision() const {
+      return 3;
+    }
+
+    void Wav2Vec2Model::initialize(ModelReader& model_reader) {
+      VocabularyInfo vocab_info;
+      vocab_info.unk_token = "[UNK]";
+      vocab_info.bos_token = "<s>";
+      vocab_info.eos_token = "</s>";
+
+      _vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info));
+      if (!_vocabulary)
+        throw std::runtime_error("Cannot load the vocabulary from the model directory");
from the model directory"); + } + + bool Wav2Vec2Model::is_quantizable(const std::string& variable_name) const { + return (Model::is_quantizable(variable_name) + && variable_name.find("conv") == std::string::npos); + } + + bool Wav2Vec2Model::is_linear_weight(const std::string& variable_name) const { + return is_quantizable(variable_name) && variable_name.find("embeddings") == std::string::npos; + } + + std::unique_ptr Wav2Vec2Model::clone() const { + return std::make_unique(*this); + } + + + std::unique_ptr Wav2Vec2Replica::create_from_model(const Model& model) { + if (!dynamic_cast(&model)) + throw std::invalid_argument("The model is not a Wav2Vec2 model"); + + const auto scoped_device_setter = model.get_scoped_device_setter(); + const auto model_ptr = model.shared_from_this(); + const auto concrete_model = std::static_pointer_cast(model_ptr); + return std::make_unique(concrete_model); + } + + Wav2Vec2Replica::Wav2Vec2Replica(const std::shared_ptr& model) + : ModelReplica(model) + , _model(model) + , _encoder(std::make_unique(*model, "encoder")) + { + } + + + StorageView Wav2Vec2Replica::encode(StorageView features, const bool to_cpu) { + PROFILE("Wav2Vec2Replica::encode"); + +#ifdef CT2_WITH_CUDA + const cuda::UseTrueFp16GemmInScope use_true_fp16_gemm(false); +#endif + + const auto scoped_device_setter = _model->get_scoped_device_setter(); + const Device device = _model->device(); + const DataType dtype = _encoder->output_type(); + features.move_to(device, dtype); + + StorageView encoder_output(dtype, device); + (*_encoder)(features, encoder_output); + + if (to_cpu) { + if (device != Device::CPU) + encoder_output = encoder_output.to(Device::CPU); + + return encoder_output; + } + + // Ensure all operations are finished before returning the output. + synchronize_stream(device); + + return encoder_output; + } + + StorageView Wav2Vec2Replica::maybe_encode(StorageView features) { + const Device device = _model->device(); + const DataType dtype = _encoder->output_type(); + + features.move_to(device, dtype); + + if (_encoder->is_encoded(features)) + return features; + + StorageView encoder_output(dtype, device); + (*_encoder)(features, encoder_output); + return encoder_output; + } + + std::future Wav2Vec2::encode(const StorageView& features, const bool to_cpu) { + return post( + [features = features.sync_copy(), to_cpu](Wav2Vec2Replica& replica) mutable { + return replica.encode(std::move(features), to_cpu); + }); + } + + } +}