Running the PyTorch implementation always gives similar output, e.g. "what's going on" #155

Open
1 task done
hjc3613 opened this issue Nov 18, 2024 · 1 comment
Labels
question Further information is requested

Comments


hjc3613 commented Nov 18, 2024

Due diligence

  • I have done my due diligence in trying to find the answer myself.

Topic

The PyTorch implementation

Question

Moshi always outputs an unexpected answer, e.g. "Hello, what's going on?"
Here is my script:

from huggingface_hub import hf_hub_download
import torch
import os
import librosa
import numpy as np
from tqdm import tqdm
from moshi.moshi.models import loaders, LMGen
import soundfile as sf
from subprocess import call
import sphn
import sentencepiece
device = torch.device('cpu')
MODEL_PATH = '/opt/ailab_mnt1/LLM_MODELS/moshi/moshika-pytorch-bf16'
mimi_weight = os.path.join(MODEL_PATH, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device='cpu')
mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.
text_tokenizer = sentencepiece.SentencePieceProcessor(os.path.join(MODEL_PATH, loaders.TEXT_TOKENIZER_NAME))

moshi_weight = os.path.join(MODEL_PATH, loaders.MOSHI_NAME)
moshi = loaders.get_moshi_lm(moshi_weight, device=device)
lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)  # this handles sampling params etc.
# 
def save_as_wav(y, sr, output_path):
    sf.write(output_path, y, sr)

def one_round_test(audio_path):
    
    wav, sample_sr = sphn.read(audio_path)
    sample_rate = mimi.sample_rate
    wav = sphn.resample(
            wav, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
        )

    wav  = torch.from_numpy(wav[None, :])
    mimi.to(device)
    # wave, sample_rate = torch.randn(1, 1, 24000 * 10), mimi.sample_rate
    with torch.no_grad():
        codes = mimi.encode(wav.to(device))  # [B, K = 8, T]
        # decoded = mimi.decode(codes)
        # save_as_mp3(decoded.numpy().squeeze(), sample_rate, audio_path.replace('.mp3', '_decoded.wav'))
        # Supports streaming too.
        frame_size = int(mimi.sample_rate / mimi.frame_rate)
        all_codes = []
        # with mimi.streaming(batch_size=1):
        for offset in tqdm(range(0, wav.shape[-1], frame_size), desc='mimi encoding...'):
                frame = wav[:, :, offset: offset + frame_size]
                if frame.shape[-1] < frame_size:
                    continue
                codes = mimi.encode(frame.to(device))
                assert codes.shape[-1] == 1, codes.shape
                all_codes.append(codes)

    # mimi.cuda()
    # moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
    
    out_wav_chunks = []
    main_text = []
    # Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
    with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
    # with torch.no_grad():
        for idx, code in tqdm(enumerate(all_codes), desc='lm_gen steping...', total=len(all_codes)):
            tokens_out = lm_gen.step(code.to(device))
            # tokens_out is [B, 1 + 8, 1]; tokens_out[:, 0] is the text token, tokens_out[:, 1:] are the audio codes.
            if tokens_out is not None:
                wav_chunk = mimi.decode(tokens_out[:, 1:])
                out_wav_chunks.append(wav_chunk)
                text_token = tokens_out[:, 0, 0][0].item()
                if text_token not in (0, 3):
                    _text = text_tokenizer.id_to_piece(text_token)
                    _text = _text.replace("▁", " ")
                    main_text.append(_text)
            # print(idx, end='\r')
    out_wav = torch.cat(out_wav_chunks, dim=-1)
    save_as_wav(out_wav.squeeze().cpu().numpy(), sample_rate, audio_path.replace('.wav', '_answer.wav'))
    print('generated_text:')
    print(''.join(main_text))

if __name__ == '__main__':
    wave_root = 'wave_data/tts_res'
    wave_files = os.listdir(wave_root)
    for file in wave_files:
        audio_path = os.path.join(wave_root, file)
        one_round_test(audio_path)
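
The frame arithmetic used in the encoding loop above can be checked without loading any model. The sketch below assumes a 24 kHz sample rate and a 12.5 Hz frame rate for Mimi (assumed values here, not read from the checkpoint) and uses a hypothetical helper name, `frame_offsets`. Like the script, it drops a trailing partial frame.

```python
def frame_offsets(num_samples, sample_rate=24000, frame_rate=12.5):
    """Return (frame_size, offsets) covering full frames only.

    sample_rate and frame_rate are assumed defaults, not values
    queried from a loaded Mimi model.
    """
    frame_size = int(sample_rate / frame_rate)
    offsets = [
        off
        for off in range(0, num_samples, frame_size)
        # Skip the last chunk if it is shorter than one frame,
        # mirroring the `continue` in the script's loop.
        if off + frame_size <= num_samples
    ]
    return frame_size, offsets


# Ten seconds of audio plus 500 leftover samples.
frame_size, offsets = frame_offsets(24000 * 10 + 500)
print(frame_size, len(offsets))  # prints "1920 125"
```

Each offset corresponds to one call to `lm_gen.step`, so the number of offsets is also the number of generation steps the model gets for that file.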

My audio files are generated with TTS from the following questions:

questions = [
    'At what temperature does water boil?',
    'What is the largest organ in the human body?',
    'Which is the largest planet in the solar system?',
    'What is the approximate speed of light?',
    'Who discovered the double helix structure of DNA?',
    'What is the deepest ocean trench on Earth?',
    'What is the normal human body temperature?',
    'What is the first element in the periodic table?',
    'In which year did humans first land on the moon?',
    'What is the approximate total length of the Great Wall?',
    'What is the highest mountain on Earth?',
    'How many years ago did dinosaurs go extinct?',
    'What is the smallest bone in the human body?',
    'What is the longest river in the world?',
    'How many times does the human heart beat per minute on average?',
]

When I run Moshi with the script above, I get this result:
[image]

Note: I ran the Python script in CPU mode, but I also tested GPU mode in online mode and got very similar WAV output.

@hjc3613 hjc3613 added the question Further information is requested label Nov 18, 2024
@LaurentMazare
Member

The released weights for moshiko/moshika are trained as a voice assistant, and as such inference usually starts with the model greeting the user with something like "Hello, what's going on?", so I guess that's the expected behavior for these weights.
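
A possible way to probe this in an offline test, sketched under assumptions: the script in the question only steps the model while input frames are available, so a short question may end before the greeting does, leaving no steps in which an answer could be generated. Padding the input with silence before and after the question would give the model time for both. The helper name `pad_with_silence` and the durations are hypothetical, and whether this yields useful answers is not confirmed here. Plain lists are used so the sketch runs without dependencies; with NumPy arrays the same idea would use `np.zeros` and `np.concatenate`.

```python
def pad_with_silence(samples, sample_rate=24000, lead_seconds=2.0, tail_seconds=5.0):
    """Return samples with zero-valued audio added before and after.

    lead_seconds leaves room for the opening greeting; tail_seconds
    leaves room for a reply after the question ends. Both values are
    illustrative guesses, not tuned settings.
    """
    lead = [0.0] * int(lead_seconds * sample_rate)
    tail = [0.0] * int(tail_seconds * sample_rate)
    return lead + list(samples) + tail


# Tiny example with a low sample rate so the lengths are easy to check.
question = [0.5] * 10
padded = pad_with_silence(question, sample_rate=100, lead_seconds=1.0, tail_seconds=2.0)
print(len(padded))  # prints "310"
```

The padded waveform would then be fed through the same encode and `lm_gen.step` loop as before, and the generated text inspected for anything following the greeting.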
